mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-05 16:24:34 +00:00
llama : sequence-length-aware batch splitting
This commit is contained in:
parent
181dadf294
commit
3a414b0be2
443
llama.cpp
443
llama.cpp
@ -2807,6 +2807,321 @@ struct llama_model {
|
||||
}
|
||||
};
|
||||
|
||||
// very similar to llama_batch,
|
||||
// but has more metadata about sequences
|
||||
struct llama_ubatch {
|
||||
bool equal_seqs;
|
||||
|
||||
int32_t n_tokens;
|
||||
int32_t n_seqs;
|
||||
|
||||
llama_token * token;
|
||||
float * embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * output;
|
||||
};
|
||||
|
||||
struct llama_sbatch_seq {
|
||||
int32_t n_seq_id;
|
||||
llama_seq_id * seq_id;
|
||||
size_t offset;
|
||||
size_t length;
|
||||
|
||||
// helper for smoother batch API transition -- can be deprecated in the future
|
||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||
};
|
||||
|
||||
// sequence-length-aware batch splitting
|
||||
struct llama_sbatch {
|
||||
// tokens left in this batch
|
||||
size_t n_tokens;
|
||||
|
||||
size_t n_embd;
|
||||
|
||||
bool logits_all; // TODO: remove once lctx.logits_all is removed too
|
||||
|
||||
// sorted indices into the batch
|
||||
std::vector<size_t> ids;
|
||||
// batch indices of the output
|
||||
std::vector<size_t> out_ids;
|
||||
std::vector<llama_sbatch_seq> seq;
|
||||
const llama_batch * batch = nullptr;
|
||||
|
||||
// buffers for the ubatch
|
||||
std::vector<llama_token> ubatch_token;
|
||||
std::vector<float> ubatch_embd;
|
||||
std::vector<llama_pos> ubatch_pos;
|
||||
std::vector<int32_t> ubatch_n_seq_id;
|
||||
std::vector<llama_seq_id *> ubatch_seq_id;
|
||||
std::vector<int8_t> ubatch_output;
|
||||
|
||||
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
|
||||
// clear empty sequences
|
||||
// the previous ubatch is assumed to be gone,
|
||||
// so nothing should refer to values in these sequences anymore.
|
||||
for (size_t i = seq.size(); i-- > 0;) {
|
||||
if (seq[i].length == 0) {
|
||||
seq.pop_back();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
ubatch_token.resize(!has_embd ? n_ubatch : 0);
|
||||
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
||||
ubatch_pos.resize(n_ubatch);
|
||||
ubatch_n_seq_id.resize(n_ubatch);
|
||||
ubatch_seq_id.resize(n_ubatch);
|
||||
ubatch_output.resize(n_ubatch);
|
||||
llama_ubatch ubatch = {
|
||||
true,
|
||||
0,
|
||||
0,
|
||||
!has_embd ? ubatch_token.data() : nullptr,
|
||||
has_embd ? ubatch_embd.data() : nullptr,
|
||||
ubatch_pos.data(),
|
||||
ubatch_n_seq_id.data(),
|
||||
ubatch_seq_id.data(),
|
||||
ubatch_output.data(),
|
||||
};
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
|
||||
GGML_ASSERT(batch != nullptr);
|
||||
GGML_ASSERT(length <= seq.length);
|
||||
if (ubatch.equal_seqs) {
|
||||
// is the new sequence of a different size than expected?
|
||||
if (ubatch.n_seqs > 0 && length != (size_t) ubatch.n_tokens / ubatch.n_seqs) {
|
||||
ubatch.equal_seqs = false;
|
||||
}
|
||||
}
|
||||
// NOTE: loops are separated for cache-friendliness
|
||||
if (batch->token) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
ubatch.token = nullptr;
|
||||
}
|
||||
if (batch->embd) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
memcpy(
|
||||
ubatch.embd + n_embd * (ubatch.n_tokens + i),
|
||||
batch->embd + n_embd * ids[seq.offset + i],
|
||||
n_embd * sizeof(float)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
ubatch.embd = nullptr;
|
||||
}
|
||||
// from here on, the else branches are deprecated;
|
||||
// they are helpers for smoother batch API transition
|
||||
if (batch->pos) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
llama_pos bi = ids[seq.offset + i];
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
|
||||
}
|
||||
}
|
||||
if (batch->n_seq_id) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.n_seq_id[ubatch.n_tokens + i] = batch->n_seq_id[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.n_seq_id[ubatch.n_tokens + i] = 1;
|
||||
}
|
||||
}
|
||||
if (batch->seq_id) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.seq_id[ubatch.n_tokens + i] = batch->seq_id[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.seq_id[ubatch.n_tokens + i] = &seq.all_seq_id;
|
||||
}
|
||||
}
|
||||
if (batch->logits) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
size_t id = ids[seq.offset + i];
|
||||
int8_t is_output = batch->logits[id];
|
||||
ubatch.output[ubatch.n_tokens + i] = is_output;
|
||||
if (is_output) { out_ids.push_back(id); }
|
||||
}
|
||||
} else if (logits_all) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.output[ubatch.n_tokens + i] = 1;
|
||||
out_ids.push_back(ids[seq.offset + i]);
|
||||
}
|
||||
} else {
|
||||
// only get last output
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
size_t id = ids[seq.offset + i];
|
||||
int8_t is_last = id == ids.size() - 1;
|
||||
ubatch.output[ubatch.n_tokens + i] = is_last;
|
||||
if (is_last) { out_ids.push_back(id); }
|
||||
}
|
||||
}
|
||||
ubatch.n_tokens += length;
|
||||
ubatch.n_seqs += seq.n_seq_id != 0; // don't count seq_ids for legacy splits
|
||||
seq.offset += length;
|
||||
seq.length -= length;
|
||||
n_tokens -= length;
|
||||
}
|
||||
|
||||
// legacy split, unknown number of sequences of unequal lengths
|
||||
llama_ubatch split_slice(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
ubatch.equal_seqs = false;
|
||||
if (!seq.empty()) {
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
|
||||
// TODO: reduce copies
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
}
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
// make batches of equal-length sequences
|
||||
llama_ubatch split_equal(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
if (!seq.empty()) {
|
||||
size_t length = 0;
|
||||
size_t n_tokens_in_ubatch = 0;
|
||||
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits
|
||||
// smallest first, because it's easier to split this way;
|
||||
// starting from the end to pop in constant time.
|
||||
for (size_t i = seq.size(); i-- > 0;) {
|
||||
llama_sbatch_seq & s = seq[i];
|
||||
GGML_ASSERT(s.length > 0);
|
||||
if (length == 0) {
|
||||
length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
}
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
n_tokens_in_ubatch += length;
|
||||
// shared prompts can't be mixed with any of their sequences,
|
||||
// so it's safer to compute them in their own ubatch
|
||||
if (s.n_seq_id > 1) { break; }
|
||||
// stop when there isn't enough space for another sequence
|
||||
if (length + n_tokens_in_ubatch > n_ubatch) { break; }
|
||||
}
|
||||
}
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
if (!seq.empty()) {
|
||||
llama_sbatch_seq & s = seq[seq.size() - 1];
|
||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
}
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
this->logits_all = logits_all;
|
||||
|
||||
n_tokens = batch.n_tokens;
|
||||
ids.resize(n_tokens);
|
||||
out_ids.clear();
|
||||
// TODO: reserve out_ids and seq
|
||||
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
ids[i] = i;
|
||||
}
|
||||
if (legacy_split) {
|
||||
seq.resize(1);
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
s.n_seq_id = 0;
|
||||
s.seq_id = nullptr;
|
||||
s.offset = 0;
|
||||
s.length = n_tokens;
|
||||
s.all_seq_id = batch.all_seq_id;
|
||||
return;
|
||||
}
|
||||
std::sort(ids.begin(), ids.end(),
|
||||
[batch](size_t a, size_t b) {
|
||||
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
||||
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
|
||||
// sort by seq_id, then by pos
|
||||
if (n_seq_a == n_seq_b) {
|
||||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < n_seq_a; ++i) {
|
||||
llama_seq_id seq_id_a = batch.seq_id[a][i];
|
||||
llama_seq_id seq_id_b = batch.seq_id[b][i];
|
||||
// smaller seq_ids go first
|
||||
if (seq_id_a != seq_id_b) {
|
||||
return seq_id_a < seq_id_b;
|
||||
}
|
||||
}
|
||||
}
|
||||
// when all else is equal, sort by pos
|
||||
if (batch.pos) {
|
||||
return batch.pos[a] < batch.pos[b];
|
||||
}
|
||||
// no pos, sort by id (assuming batch.all_pos_1 is positive)
|
||||
return a < b;
|
||||
}
|
||||
// shared prompts go first
|
||||
return n_seq_a > n_seq_b;
|
||||
}
|
||||
);
|
||||
// init seq
|
||||
llama_sbatch_seq * last_seq = nullptr;
|
||||
|
||||
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
const size_t bi = ids[i];
|
||||
const size_t s_len = seq.size();
|
||||
const int32_t n_seqs = batch.n_seq_id[bi];
|
||||
llama_seq_id * seq_ids = batch.seq_id[bi];
|
||||
if (last_seq != nullptr) {
|
||||
bool same = n_seqs == last_seq->n_seq_id;
|
||||
for (int32_t j = 0; same && j < n_seqs; ++j) {
|
||||
if (seq_ids[j] != last_seq->seq_id[j]) {
|
||||
same = false;
|
||||
}
|
||||
}
|
||||
if (same) {
|
||||
last_seq->length += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
|
||||
seq.push_back(new_seq);
|
||||
last_seq = &seq[s_len];
|
||||
}
|
||||
} else {
|
||||
llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
|
||||
seq.push_back(new_seq);
|
||||
}
|
||||
// keep shared prompts first at the end, then sort by length descending.
|
||||
std::sort(seq.begin(), seq.end(),
|
||||
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
|
||||
if (a.n_seq_id == b.n_seq_id) {
|
||||
return a.length > b.length;
|
||||
}
|
||||
return a.n_seq_id < b.n_seq_id;
|
||||
}
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_context {
|
||||
llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
|
||||
~llama_context() {
|
||||
@ -2832,6 +3147,9 @@ struct llama_context {
|
||||
// key + value cache for self-attention, and/or recurrent state cache
|
||||
struct llama_cache cache;
|
||||
|
||||
// sequence-length-aware batch splitting
|
||||
llama_sbatch sbatch;
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
bool has_evaluated_once = false;
|
||||
@ -3126,7 +3444,7 @@ static bool llama_cache_init(
|
||||
// to the first cell of the slot.
|
||||
static bool llama_cache_find_slot(
|
||||
struct llama_cache & cache,
|
||||
const struct llama_batch & batch) {
|
||||
const struct llama_ubatch & batch) {
|
||||
const uint32_t kv_size = cache.kv.size;
|
||||
const uint32_t rs_size = cache.rs.size;
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
@ -7533,7 +7851,7 @@ static struct ggml_tensor * llm_build_inp_embd(
|
||||
struct ggml_context * ctx,
|
||||
struct llama_context & lctx,
|
||||
const llama_hparams & hparams,
|
||||
const llama_batch & batch,
|
||||
const llama_ubatch & batch,
|
||||
struct ggml_tensor * tok_embd,
|
||||
const llm_build_cb & cb) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
@ -8107,7 +8425,7 @@ struct llm_build_context {
|
||||
llama_context & lctx;
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
const llama_batch & batch;
|
||||
const llama_ubatch & batch;
|
||||
const llama_kv_cache & kv_self;
|
||||
const llama_rs_cache & rs_self;
|
||||
|
||||
@ -8153,7 +8471,7 @@ struct llm_build_context {
|
||||
// TODO: consider making the entire interface noexcept
|
||||
llm_build_context(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
const llama_ubatch & batch,
|
||||
const llm_build_cb & cb,
|
||||
bool worst_case) :
|
||||
model (lctx.model),
|
||||
@ -12215,8 +12533,8 @@ struct llm_build_context {
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
llama_batch dummy;
|
||||
dummy.n_tokens = 0;
|
||||
llama_ubatch dummy = {};
|
||||
dummy.equal_seqs = true;
|
||||
|
||||
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
|
||||
|
||||
@ -12232,8 +12550,8 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
|
||||
llama_batch dummy;
|
||||
dummy.n_tokens = 0;
|
||||
llama_ubatch dummy = {};
|
||||
dummy.equal_seqs = true;
|
||||
|
||||
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
|
||||
|
||||
@ -12250,7 +12568,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
const llama_ubatch & batch,
|
||||
bool worst_case) {
|
||||
const auto & model = lctx.model;
|
||||
|
||||
@ -12438,7 +12756,7 @@ static void llama_set_k_shift(llama_context & lctx) {
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
||||
//
|
||||
// set input data
|
||||
//
|
||||
@ -12478,10 +12796,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
data[i] = i;
|
||||
}
|
||||
} else if (batch.logits) {
|
||||
} else if (batch.output) {
|
||||
int32_t n_outputs = 0;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
if (batch.logits[i]) {
|
||||
if (batch.output[i]) {
|
||||
data[n_outputs++] = i;
|
||||
}
|
||||
}
|
||||
@ -12835,11 +13153,6 @@ static int llama_decode_internal(
|
||||
|
||||
const auto n_ubatch = cparams.n_ubatch;
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id_arr;
|
||||
std::vector<std::vector<llama_seq_id>> seq_id;
|
||||
|
||||
// count outputs
|
||||
if (batch_all.logits) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
@ -12852,55 +13165,29 @@ static int llama_decode_internal(
|
||||
n_outputs = 1;
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all);
|
||||
|
||||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
||||
return -2;
|
||||
};
|
||||
|
||||
// set output mappings
|
||||
if (batch_all.logits) {
|
||||
int32_t i_logits = 0;
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
if (batch_all.logits[i]) {
|
||||
lctx.output_ids[i] = i_logits++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint32_t i = 0; i < n_outputs; ++i) {
|
||||
lctx.output_ids[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
|
||||
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||
llama_batch u_batch = {
|
||||
/* .n_tokens = */ (int32_t) n_tokens,
|
||||
/* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
|
||||
/* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
|
||||
/* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
|
||||
/* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
|
||||
/* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
|
||||
/* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
|
||||
/* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
|
||||
/* .all_pos_1 = */ batch_all.all_pos_1,
|
||||
/* .all_seq_id = */ batch_all.all_seq_id,
|
||||
};
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
// TODO: deprecate slice splits in favor of equal splits
|
||||
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
|
||||
const uint32_t n_tokens = u_batch.n_tokens;
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
int32_t n_outputs_new = 0;
|
||||
|
||||
if (u_batch.logits) {
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_outputs_new += u_batch.logits[i] != 0;
|
||||
}
|
||||
} else if (n_outputs == n_tokens_all) {
|
||||
if (n_outputs == n_tokens_all) {
|
||||
n_outputs_new = n_tokens;
|
||||
} else {
|
||||
// keep last output only
|
||||
if (cur_token + n_tokens >= n_tokens_all) {
|
||||
n_outputs_new = 1;
|
||||
GGML_ASSERT(u_batch.output);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_outputs_new += u_batch.output[i] != 0;
|
||||
}
|
||||
}
|
||||
|
||||
@ -12911,32 +13198,6 @@ static int llama_decode_internal(
|
||||
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
GGML_ASSERT(n_threads > 0);
|
||||
|
||||
// helpers for smoother batch API transition
|
||||
// after deprecating the llama_eval calls, these will be removed
|
||||
if (u_batch.pos == nullptr) {
|
||||
pos.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
|
||||
}
|
||||
|
||||
u_batch.pos = pos.data();
|
||||
}
|
||||
|
||||
if (u_batch.seq_id == nullptr) {
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_id.resize(n_tokens);
|
||||
seq_id_arr.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_seq_id[i] = 1;
|
||||
seq_id[i].resize(1);
|
||||
seq_id[i][0] = u_batch.all_seq_id;
|
||||
seq_id_arr[i] = seq_id[i].data();
|
||||
}
|
||||
|
||||
u_batch.n_seq_id = n_seq_id.data();
|
||||
u_batch.seq_id = seq_id_arr.data();
|
||||
}
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
llama_kv_cache_update(&lctx);
|
||||
@ -12945,6 +13206,7 @@ static int llama_decode_internal(
|
||||
return 1;
|
||||
}
|
||||
|
||||
// TODO: move into llama_cache_find_slot
|
||||
if (kv_self.size > 0) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
@ -13108,6 +13370,12 @@ static int llama_decode_internal(
|
||||
#endif
|
||||
}
|
||||
|
||||
// set output mappings
|
||||
GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
|
||||
for (size_t i = 0; i < n_outputs; ++i) {
|
||||
lctx.output_ids[lctx.sbatch.out_ids[i]] = i;
|
||||
}
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
lctx.n_outputs = n_outputs;
|
||||
|
||||
@ -13398,10 +13666,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||
if (need_reserve) {
|
||||
// TODO: extract to a function
|
||||
// build worst-case graph
|
||||
int n_seqs = 1; // TODO: worst-case number of sequences
|
||||
int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
|
||||
int n_past = lctx.cparams.n_ctx - n_tokens;
|
||||
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
|
||||
llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
@ -17345,10 +17614,11 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
|
||||
// build worst-case graph
|
||||
int n_seqs = 1; // TODO: worst-case number of sequences
|
||||
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
int n_past = cparams.n_ctx - n_tokens;
|
||||
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
|
||||
llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
|
||||
@ -18662,8 +18932,9 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||
|
||||
// Allocate the new cells for the slot
|
||||
if (cell_count) {
|
||||
llama_batch batch = llama_batch_init(cell_count, 0, 1);
|
||||
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
batch.n_tokens = cell_count;
|
||||
batch.n_seqs = 1;
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
memcpy(&pos, inp, sizeof(pos));
|
||||
@ -18674,7 +18945,6 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||
batch.seq_id[i][0] = dest_seq_id;
|
||||
}
|
||||
if (!llama_cache_find_slot(cache, batch)) {
|
||||
llama_batch_free(batch);
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return 0;
|
||||
}
|
||||
@ -18686,9 +18956,6 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
||||
|
||||
// Cleanup
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
|
Loading…
Reference in New Issue
Block a user