mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
llama : greatly reduce output buffer memory usage (#6122)
* llama : greatly reduce logits memory usage * llama : more compact state saving and reloading * llama : fix lctx.n_outputs not being set before building graph * perplexity : adapt to the logits API changes * perplexity : fix Winogrande, use correct logits for second choice start The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore. * perplexity : normalize spaces and punctuation in Winogrande sentences * llama : fix embedding conditions * llama : fix llama_get_embeddings_ith when the resulting id is 0 * llama : fix wrong n_outputs in llama_set_inputs A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch. * llama : fix not-skipping outputs of non-causal models * llama : fix running a batch with n_outputs == 0 It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests. * llama : keep same graph topology even when n_outputs == 0 * ggml : saner ggml_can_repeat with empty tensors * ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1 * ggml : do not multi-thread ops returning empty tensors * ggml : make ggml_is_empty public and work with views * llama : use a vector for ctx->output_ids * llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future. * ggml : skip empty tensors in all backends * llama : fix llama_output_reserve nullptr deref when new_size is 0 * perplexity : make Winogrande work as it does on master The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review. * llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them * llama : handle errors from llama_output_reserve at call sites * perplexity : make hellaswag and multiple-choice outputs identical to master Due to how the KV cache is updated, the logprobs for tokens in a batch are very slightly affected by the other tokens present in the batch, so to make hellaswag and multiple-choice return exactly the same results as on master, the last token of each sequence needs to be evaluated even though its output is not used at all. This will probably be changed back in the future to make these benchmarks a tiny bit faster. * perplexity : fix division by zero when using less than 100 multiple-choice tasks * llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. * llama : minor ggml-ci * readme : update recent API changes, and warn about Vulkan --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
55c1b2a3bb
commit
557410b8f0
10
README.md
10
README.md
@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
|||||||
|
|
||||||
### Recent API changes
|
### Recent API changes
|
||||||
|
|
||||||
|
- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
|
||||||
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
|
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
|
||||||
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
|
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
|
||||||
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
|
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
|
||||||
@ -630,6 +631,15 @@ Building the program with BLAS support may lead to some performance improvements
|
|||||||
|
|
||||||
- #### Vulkan
|
- #### Vulkan
|
||||||
|
|
||||||
|
> [!WARNING]
|
||||||
|
>
|
||||||
|
> Vulkan support has been broken in https://github.com/ggerganov/llama.cpp/pull/6122
|
||||||
|
> due to relying on `GGML_OP_GET_ROWS` which is not yet properly supported by the Vulkan backend,
|
||||||
|
> but should be fixed relatively soon (possibly in https://github.com/ggerganov/llama.cpp/pull/6155
|
||||||
|
> (ref: https://github.com/ggerganov/llama.cpp/pull/6122#issuecomment-2015327635)).
|
||||||
|
>
|
||||||
|
> Meanwhile, if you want to use the Vulkan backend, you should use the commit right before the breaking change, https://github.com/ggerganov/llama.cpp/commit/55c1b2a3bbd470e9e2a3a0618b92cf64a885f806
|
||||||
|
|
||||||
**With docker**:
|
**With docker**:
|
||||||
|
|
||||||
You don't need to install Vulkan SDK. It will be installed inside the container.
|
You don't need to install Vulkan SDK. It will be installed inside the container.
|
||||||
|
@ -424,6 +424,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
|
|||||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: use batch.logits to save computations instead of relying on logits_all == true
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
@ -132,7 +132,6 @@ int main(int argc, char ** argv) {
|
|||||||
llama_context * ctx = NULL;
|
llama_context * ctx = NULL;
|
||||||
|
|
||||||
// load the target model
|
// load the target model
|
||||||
params.logits_all = true;
|
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
// load the prompts from an external file if there are any
|
// load the prompts from an external file if there are any
|
||||||
|
@ -380,6 +380,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
const int batch_size = std::min(end - batch_start, n_batch);
|
||||||
|
|
||||||
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
||||||
|
// TODO: use llama_batch.logits instead of relying on logits_all == true
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
||||||
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
@ -552,6 +553,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
const int batch_start = start + j * n_batch;
|
const int batch_start = start + j * n_batch;
|
||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
const int batch_size = std::min(end - batch_start, n_batch);
|
||||||
|
|
||||||
|
int n_outputs = 0;
|
||||||
|
|
||||||
batch.n_tokens = 0;
|
batch.n_tokens = 0;
|
||||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||||
int seq_start = batch_start + seq*n_ctx;
|
int seq_start = batch_start + seq*n_ctx;
|
||||||
@ -566,11 +569,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
|
|
||||||
for (int k = 0; k < batch_size; ++k) {
|
for (int k = 0; k < batch_size; ++k) {
|
||||||
const int idx = seq*n_ctx + k;
|
const int idx = seq*n_ctx + k;
|
||||||
batch.token[idx] = tokens[seq_start + k];
|
batch.token [idx] = tokens[seq_start + k];
|
||||||
batch.pos[idx] = j*n_batch + k;
|
batch.pos [idx] = j*n_batch + k;
|
||||||
batch.n_seq_id[idx] = 1;
|
batch.n_seq_id[idx] = 1;
|
||||||
batch.seq_id[idx][0] = seq;
|
batch.seq_id [idx][0] = seq;
|
||||||
batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
|
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
|
||||||
|
|
||||||
|
n_outputs += batch.logits[idx] != 0;
|
||||||
}
|
}
|
||||||
batch.n_tokens += batch_size;
|
batch.n_tokens += batch_size;
|
||||||
|
|
||||||
@ -583,9 +588,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (num_batches > 1) {
|
if (num_batches > 1 && n_outputs > 0) {
|
||||||
const auto * batch_logits = llama_get_logits(ctx);
|
const auto * batch_logits = llama_get_logits(ctx);
|
||||||
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -604,14 +609,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
|
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||||
|
|
||||||
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
||||||
if (!params.logits_file.empty()) {
|
if (!params.logits_file.empty()) {
|
||||||
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
|
process_logits(logits_stream, n_vocab, all_logits,
|
||||||
tokens_data, n_ctx - 1 - first,
|
tokens_data, n_ctx - 1 - first,
|
||||||
workers, log_probs, nll, nll2);
|
workers, log_probs, nll, nll2);
|
||||||
} else {
|
} else {
|
||||||
process_logits(n_vocab, all_logits + first*n_vocab,
|
process_logits(n_vocab, all_logits,
|
||||||
tokens_data, n_ctx - 1 - first,
|
tokens_data, n_ctx - 1 - first,
|
||||||
workers, nll, nll2,
|
workers, nll, nll2,
|
||||||
logit_history.data() + start + seq*n_ctx + first,
|
logit_history.data() + start + seq*n_ctx + first,
|
||||||
@ -652,6 +658,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
|
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
|
||||||
|
int prev_outputs = 0;
|
||||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
|
|
||||||
@ -672,7 +679,14 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
|
int n_outputs = 0;
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
n_outputs += batch_view.logits[i] != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
|
||||||
|
|
||||||
|
prev_outputs += n_outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -779,7 +793,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
size_t ending_logprob_count[4];
|
size_t ending_logprob_count[4];
|
||||||
double ending_logprob[4];
|
double ending_logprob[4];
|
||||||
|
|
||||||
size_t i_batch; // starting index in the llama_batch
|
size_t i_logits; // starting index of logits in the llama_batch
|
||||||
size_t common_prefix; // max number of initial tokens that are the same in all sentences
|
size_t common_prefix; // max number of initial tokens that are the same in all sentences
|
||||||
size_t required_tokens; // needed number of tokens to evaluate all 4 endings
|
size_t required_tokens; // needed number of tokens to evaluate all 4 endings
|
||||||
std::vector<llama_token> seq_tokens[4];
|
std::vector<llama_token> seq_tokens[4];
|
||||||
@ -844,9 +858,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
const int max_tasks_per_batch = 32;
|
const int max_tasks_per_batch = 32;
|
||||||
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
llama_batch batch = llama_batch_init(n_ctx, 0, 4);
|
||||||
|
|
||||||
std::vector<float> tok_logits(n_vocab);
|
std::vector<float> tok_logits(n_vocab);
|
||||||
|
// TODO: this could be made smaller; it's currently the worst-case size
|
||||||
std::vector<float> batch_logits(n_vocab*n_ctx);
|
std::vector<float> batch_logits(n_vocab*n_ctx);
|
||||||
|
|
||||||
std::vector<std::pair<size_t, llama_token>> eval_pairs;
|
std::vector<std::pair<size_t, llama_token>> eval_pairs;
|
||||||
@ -857,16 +872,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
int n_cur = 0;
|
int n_cur = 0;
|
||||||
|
|
||||||
size_t i1 = i0;
|
size_t i1 = i0;
|
||||||
size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
|
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
// batch as much tasks as possible into the available context
|
// batch as much tasks as possible into the available context
|
||||||
// each task has 4 unique seuqnce ids - one for each ending
|
// each task has 4 unique sequence ids - one for each ending
|
||||||
// the common prefix is shared among the 4 sequences to save tokens
|
// the common prefix is shared among the 4 sequences to save tokens
|
||||||
// we extract logits only from the last common token and from all ending tokens of each sequence
|
// we extract logits only from the last common token and from all ending tokens of each sequence
|
||||||
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
|
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
|
||||||
auto & hs_cur = hs_data[i1];
|
auto & hs_cur = hs_data[i1];
|
||||||
|
int n_logits = 0;
|
||||||
|
|
||||||
const int s0 = 4*(i1 - i0);
|
const int s0 = 4*(i1 - i0);
|
||||||
if (s0 + 4 > max_seq) {
|
if (s0 + 4 > max_seq) {
|
||||||
@ -874,18 +890,23 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
||||||
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
|
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||||
|
n_logits += 1;
|
||||||
|
|
||||||
for (int s = 0; s < 4; ++s) {
|
for (int s = 0; s < 4; ++s) {
|
||||||
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
|
const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
|
||||||
llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
|
// TODO: don't evaluate the last token of each sequence
|
||||||
|
for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
|
||||||
|
const bool needs_logits = i < seq_tokens_size - 1;
|
||||||
|
llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
|
||||||
|
n_logits += needs_logits;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hs_cur.i_batch = i_batch;
|
hs_cur.i_logits = i_logits;
|
||||||
i_batch += hs_cur.required_tokens;
|
i_logits += n_logits;
|
||||||
|
|
||||||
n_cur += hs_data[i1].required_tokens;
|
n_cur += hs_data[i1].required_tokens;
|
||||||
if (++i1 == hs_task_count) {
|
if (++i1 == hs_task_count) {
|
||||||
@ -911,12 +932,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
eval_pairs.clear();
|
eval_pairs.clear();
|
||||||
for (size_t i = i0; i < i1; ++i) {
|
for (size_t i = i0; i < i1; ++i) {
|
||||||
auto & hs_cur = hs_data[i];
|
auto & hs_cur = hs_data[i];
|
||||||
size_t li = hs_cur.common_prefix;
|
size_t li = 1; // skip the last logit of the common prefix (computed separately below)
|
||||||
for (int s = 0; s < 4; ++s) {
|
for (int s = 0; s < 4; ++s) {
|
||||||
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
|
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
|
||||||
eval_pairs.emplace_back(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]);
|
eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
|
||||||
}
|
}
|
||||||
++li;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Then we do the actual calculation
|
// Then we do the actual calculation
|
||||||
@ -928,7 +948,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
for (size_t i = i0; i < i1; ++i) {
|
for (size_t i = i0; i < i1; ++i) {
|
||||||
auto & hs_cur = hs_data[i];
|
auto & hs_cur = hs_data[i];
|
||||||
|
|
||||||
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
|
// get the logits of the last token of the common prefix
|
||||||
|
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
|
||||||
|
|
||||||
const auto first_probs = softmax(tok_logits);
|
const auto first_probs = softmax(tok_logits);
|
||||||
|
|
||||||
@ -978,7 +999,7 @@ struct winogrande_entry {
|
|||||||
std::array<std::string, 2> choices;
|
std::array<std::string, 2> choices;
|
||||||
int answer;
|
int answer;
|
||||||
|
|
||||||
size_t i_batch;
|
size_t i_logits;
|
||||||
size_t common_prefix;
|
size_t common_prefix;
|
||||||
size_t required_tokens;
|
size_t required_tokens;
|
||||||
size_t n_base1; // number of tokens for context + choice 1
|
size_t n_base1; // number of tokens for context + choice 1
|
||||||
@ -1104,6 +1125,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
task.common_prefix++;
|
task.common_prefix++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: the last token of each of the sequences don't need to be evaluated
|
||||||
task.required_tokens = task.common_prefix +
|
task.required_tokens = task.common_prefix +
|
||||||
task.seq_tokens[0].size() - task.common_prefix +
|
task.seq_tokens[0].size() - task.common_prefix +
|
||||||
task.seq_tokens[1].size() - task.common_prefix;
|
task.seq_tokens[1].size() - task.common_prefix;
|
||||||
@ -1121,9 +1143,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
const int max_tasks_per_batch = 128;
|
const int max_tasks_per_batch = 128;
|
||||||
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
llama_batch batch = llama_batch_init(n_ctx, 0, 2);
|
||||||
|
|
||||||
std::vector<float> tok_logits(n_vocab);
|
std::vector<float> tok_logits(n_vocab);
|
||||||
|
// TODO: this could be made smaller; it's currently the worst-case size
|
||||||
std::vector<float> batch_logits(n_vocab*n_ctx);
|
std::vector<float> batch_logits(n_vocab*n_ctx);
|
||||||
|
|
||||||
std::vector<std::pair<size_t, llama_token>> eval_pairs;
|
std::vector<std::pair<size_t, llama_token>> eval_pairs;
|
||||||
@ -1137,29 +1160,33 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
int n_cur = 0;
|
int n_cur = 0;
|
||||||
|
|
||||||
size_t i1 = i0;
|
size_t i1 = i0;
|
||||||
size_t i_batch = 0;
|
size_t i_logits = 0;
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
|
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
|
||||||
|
int n_logits = 0;
|
||||||
const int s0 = 2*(i1 - i0);
|
const int s0 = 2*(i1 - i0);
|
||||||
if (s0 + 2 > max_seq) {
|
if (s0 + 2 > max_seq) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
|
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
|
||||||
llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
|
llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
n_logits += 1;
|
||||||
|
|
||||||
for (int s = 0; s < 2; ++s) {
|
for (int s = 0; s < 2; ++s) {
|
||||||
|
// TODO: end before the last token, no need to predict past the end of the sequences
|
||||||
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
|
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
|
||||||
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
|
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
|
||||||
|
n_logits += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data[i1].i_batch = i_batch;
|
data[i1].i_logits = i_logits;
|
||||||
i_batch += data[i1].required_tokens;
|
i_logits += n_logits;
|
||||||
|
|
||||||
n_cur += data[i1].required_tokens;
|
n_cur += data[i1].required_tokens;
|
||||||
if (++i1 == data.size()) {
|
if (++i1 == data.size()) {
|
||||||
@ -1190,15 +1217,16 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
|
|
||||||
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
|
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
|
||||||
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
|
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
|
||||||
size_t li = n_base1 - 1;
|
size_t li = n_base1 - task.common_prefix;
|
||||||
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
|
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
|
||||||
eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[0][j+1]);
|
eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
|
||||||
}
|
}
|
||||||
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
|
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
|
||||||
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
|
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
|
||||||
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
|
// FIXME: this uses the wrong first logits when not skipping the choice word
|
||||||
|
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
|
||||||
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
|
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
|
||||||
eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[1][j+1]);
|
eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
|
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
|
||||||
@ -1287,7 +1315,7 @@ struct multiple_choice_task {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// For evaluation
|
// For evaluation
|
||||||
size_t i_batch; // starting index in the llama_batch
|
size_t i_logits; // starting index of logits in the llama_batch
|
||||||
size_t common_prefix; // max number of initial tokens that are the same in all sentences
|
size_t common_prefix; // max number of initial tokens that are the same in all sentences
|
||||||
size_t required_tokens; // needed number of tokens to evaluate all answers
|
size_t required_tokens; // needed number of tokens to evaluate all answers
|
||||||
std::vector<std::vector<llama_token>> seq_tokens;
|
std::vector<std::vector<llama_token>> seq_tokens;
|
||||||
@ -1366,7 +1394,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|||||||
std::vector<uint32_t> task_pos(n_task);
|
std::vector<uint32_t> task_pos(n_task);
|
||||||
strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
|
strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
|
||||||
if (strstream.fail()) {
|
if (strstream.fail()) {
|
||||||
printf("%s: failed to raad task positions from prompt\n", __func__);
|
printf("%s: failed to read task positions from prompt\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1447,7 +1475,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int n_dot = n_task/100;
|
int n_dot = std::max((int) n_task/100, 1);
|
||||||
int i_task = 0;
|
int i_task = 0;
|
||||||
for (auto& task : tasks) {
|
for (auto& task : tasks) {
|
||||||
++i_task;
|
++i_task;
|
||||||
@ -1491,17 +1519,18 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|||||||
int n_cur = 0;
|
int n_cur = 0;
|
||||||
|
|
||||||
size_t i1 = i0;
|
size_t i1 = i0;
|
||||||
size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
|
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
// batch as much tasks as possible into the available context
|
// batch as much tasks as possible into the available context
|
||||||
// each task has 4 unique seuqnce ids - one for each ending
|
// each task has 4 unique sequence ids - one for each ending
|
||||||
// the common prefix is shared among the 4 sequences to save tokens
|
// the common prefix is shared among the 4 sequences to save tokens
|
||||||
// we extract logits only from the last common token and from all ending tokens of each sequence
|
// we extract logits only from the last common token and from all ending tokens of each sequence
|
||||||
int s0 = 0;
|
int s0 = 0;
|
||||||
while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
|
while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
|
||||||
auto& cur_task = tasks[i1];
|
auto& cur_task = tasks[i1];
|
||||||
|
int n_logits = 0;
|
||||||
|
|
||||||
int num_answers = cur_task.seq_tokens.size();
|
int num_answers = cur_task.seq_tokens.size();
|
||||||
if (s0 + num_answers > max_seq) {
|
if (s0 + num_answers > max_seq) {
|
||||||
@ -1518,17 +1547,22 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|||||||
llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
|
llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||||
|
n_logits += 1;
|
||||||
|
|
||||||
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
||||||
for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) {
|
const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
|
||||||
llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true);
|
// TODO: don't evaluate the last token of each sequence
|
||||||
|
for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
|
||||||
|
const bool needs_logits = i < seq_tokens_size - 1;
|
||||||
|
llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
|
||||||
|
n_logits += needs_logits;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s0 += num_answers;
|
s0 += num_answers;
|
||||||
|
|
||||||
cur_task.i_batch = i_batch;
|
cur_task.i_logits = i_logits;
|
||||||
i_batch += cur_task.required_tokens;
|
i_logits += n_logits;
|
||||||
|
|
||||||
n_cur += cur_task.required_tokens;
|
n_cur += cur_task.required_tokens;
|
||||||
if (++i1 == tasks.size()) {
|
if (++i1 == tasks.size()) {
|
||||||
@ -1554,12 +1588,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|||||||
eval_pairs.clear();
|
eval_pairs.clear();
|
||||||
for (size_t i = i0; i < i1; ++i) {
|
for (size_t i = i0; i < i1; ++i) {
|
||||||
auto& cur_task = tasks[i];
|
auto& cur_task = tasks[i];
|
||||||
size_t li = cur_task.common_prefix;
|
size_t li = 1; // skip the last logit of the common prefix (computed separately below)
|
||||||
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
||||||
for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
|
for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
|
||||||
eval_pairs.emplace_back(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1]);
|
eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
|
||||||
}
|
}
|
||||||
++li;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Then we do the actual calculation
|
// Then we do the actual calculation
|
||||||
@ -1578,7 +1611,8 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|||||||
//}
|
//}
|
||||||
//printf("\n common_prefix: %zu\n", cur_task.common_prefix);
|
//printf("\n common_prefix: %zu\n", cur_task.common_prefix);
|
||||||
|
|
||||||
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
|
// get the logits of the last token of the common prefix
|
||||||
|
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
|
||||||
|
|
||||||
const auto first_probs = softmax(tok_logits);
|
const auto first_probs = softmax(tok_logits);
|
||||||
|
|
||||||
@ -1730,6 +1764,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
|||||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: use llama_batch.logits instead of relying on logits_all == true
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return;
|
return;
|
||||||
|
@ -747,7 +747,8 @@ struct server_context {
|
|||||||
{
|
{
|
||||||
const int32_t n_batch = llama_n_batch(ctx);
|
const int32_t n_batch = llama_n_batch(ctx);
|
||||||
|
|
||||||
batch = llama_batch_init(n_batch, 0, params.n_parallel);
|
// only a single seq_id per token is needed
|
||||||
|
batch = llama_batch_init(n_batch, 0, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.init();
|
metrics.init();
|
||||||
|
@ -65,7 +65,6 @@ int main(int argc, char ** argv) {
|
|||||||
llama_context * ctx_dft = NULL;
|
llama_context * ctx_dft = NULL;
|
||||||
|
|
||||||
// load the target model
|
// load the target model
|
||||||
params.logits_all = true;
|
|
||||||
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
|
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
// load the draft model
|
// load the draft model
|
||||||
|
@ -2505,7 +2505,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
ggml_tensor * node = cgraph->nodes[i];
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
||||||
if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|||||||
struct ggml_tensor * dst = gf->nodes[i];
|
struct ggml_tensor * dst = gf->nodes[i];
|
||||||
GGML_ASSERT(dst->data != nullptr);
|
GGML_ASSERT(dst->data != nullptr);
|
||||||
|
|
||||||
|
if (ggml_is_empty(dst)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
switch (dst->op) {
|
switch (dst->op) {
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
|
@ -847,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||||
struct ggml_tensor * dst = gf->nodes[i];
|
struct ggml_tensor * dst = gf->nodes[i];
|
||||||
|
|
||||||
|
if (ggml_is_empty(dst)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
switch (dst->op) {
|
switch (dst->op) {
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
|
@ -2234,6 +2234,11 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(gg
|
|||||||
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
|
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
|
||||||
for (int i = 0; i < graph->n_nodes; ++i) {
|
for (int i = 0; i < graph->n_nodes; ++i) {
|
||||||
ggml_tensor * node = graph->nodes[i];
|
ggml_tensor * node = graph->nodes[i];
|
||||||
|
|
||||||
|
if (ggml_is_empty(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
switch (node->op) {
|
switch (node->op) {
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0);
|
ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0);
|
||||||
|
@ -16973,7 +16973,7 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
|
|||||||
params.ith = 0;
|
params.ith = 0;
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
ggml_tensor * node = cgraph->nodes[i];
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
|
@ -5566,7 +5566,7 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
|
|||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
ggml_tensor * node = cgraph->nodes[i];
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
||||||
if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
20
ggml.c
20
ggml.c
@ -2607,6 +2607,16 @@ static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
|
|||||||
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GGML_CALL bool ggml_is_empty(const struct ggml_tensor * tensor) {
|
||||||
|
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||||
|
if (tensor->ne[i] == 0) {
|
||||||
|
// empty if any dimension has no elements
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||||
|
|
||||||
@ -2621,7 +2631,7 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
|
|||||||
static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||||
|
|
||||||
return
|
return ggml_is_empty(t0) ? ggml_is_empty(t1) :
|
||||||
(t1->ne[0]%t0->ne[0] == 0) &&
|
(t1->ne[0]%t0->ne[0] == 0) &&
|
||||||
(t1->ne[1]%t0->ne[1] == 0) &&
|
(t1->ne[1]%t0->ne[1] == 0) &&
|
||||||
(t1->ne[2]%t0->ne[2] == 0) &&
|
(t1->ne[2]%t0->ne[2] == 0) &&
|
||||||
@ -16114,7 +16124,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
|
|||||||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||||
GGML_ASSERT(params);
|
GGML_ASSERT(params);
|
||||||
|
|
||||||
if (tensor->op == GGML_OP_NONE) {
|
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -17983,6 +17993,12 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
|
|||||||
static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
|
static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
|
||||||
int n_tasks = 0;
|
int n_tasks = 0;
|
||||||
|
|
||||||
|
if (ggml_is_empty(node)) {
|
||||||
|
// no need to multi-thread a no-op
|
||||||
|
n_tasks = 1;
|
||||||
|
return n_tasks;
|
||||||
|
}
|
||||||
|
|
||||||
switch (node->op) {
|
switch (node->op) {
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
|
1
ggml.h
1
ggml.h
@ -750,6 +750,7 @@ extern "C" {
|
|||||||
GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
|
GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
|
||||||
GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
|
GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
|
||||||
GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
||||||
|
GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
|
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
|
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
|
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
|
||||||
|
24
llama.h
24
llama.h
@ -39,7 +39,7 @@
|
|||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 4
|
#define LLAMA_SESSION_VERSION 5
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -678,23 +678,29 @@ extern "C" {
|
|||||||
LLAMA_API void llama_synchronize(struct llama_context * ctx);
|
LLAMA_API void llama_synchronize(struct llama_context * ctx);
|
||||||
|
|
||||||
// Token logits obtained from the last call to llama_decode()
|
// Token logits obtained from the last call to llama_decode()
|
||||||
// The logits for the last token are stored in the last row
|
// The logits for which llama_batch.logits[i] != 0 are stored contiguously
|
||||||
// Logits for which llama_batch.logits[i] == 0 are undefined
|
// in the order they have appeared in the batch.
|
||||||
// Rows: n_tokens provided with llama_batch
|
// Rows: number of tokens for which llama_batch.logits[i] != 0
|
||||||
// Cols: n_vocab
|
// Cols: n_vocab
|
||||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||||
|
|
||||||
// Logits for the ith token. Equivalent to:
|
// Logits for the ith token. Equivalent to:
|
||||||
// llama_get_logits(ctx) + i*n_vocab
|
// llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
|
||||||
|
// returns NULL for invalid ids.
|
||||||
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||||
|
|
||||||
// Get all output token embeddings
|
// Get all output token embeddings.
|
||||||
// shape: [n_tokens*n_embd] (1-dimensional)
|
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
|
||||||
|
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
|
||||||
|
// in the order they have appeared in the batch.
|
||||||
|
// shape: [n_outputs*n_embd]
|
||||||
|
// Otherwise, returns NULL.
|
||||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Get the embeddings for the ith token
|
// Get the embeddings for the ith token. Equivalent to:
|
||||||
// llama_get_embeddings(ctx) + i*n_embd
|
// llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
||||||
// shape: [n_embd] (1-dimensional)
|
// shape: [n_embd] (1-dimensional)
|
||||||
|
// returns NULL for invalid ids.
|
||||||
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||||
|
|
||||||
// Get the embeddings for a sequence id
|
// Get the embeddings for a sequence id
|
||||||
|
Loading…
Reference in New Issue
Block a user