mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 09:11:46 +00:00
llama : avoid copies for simple batch splits
This commit is contained in:
parent
61200ef29f
commit
eb589d5e36
51
llama.cpp
51
llama.cpp
@ -3143,13 +3143,19 @@ struct llama_sbatch {
|
||||
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
|
||||
// NOTE: loops are separated for cache-friendliness
|
||||
if (batch->token) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.token = batch->token + seq.offset;
|
||||
}
|
||||
} else {
|
||||
ubatch.token = nullptr;
|
||||
}
|
||||
if (batch->embd) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
memcpy(
|
||||
ubatch.embd + n_embd * (ubatch.n_tokens + i),
|
||||
@ -3157,22 +3163,31 @@ struct llama_sbatch {
|
||||
n_embd * sizeof(float)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.embd = batch->embd + seq.offset;
|
||||
}
|
||||
} else {
|
||||
ubatch.embd = nullptr;
|
||||
}
|
||||
// from here on, the else branches are deprecated;
|
||||
// they are helpers for smoother batch API transition
|
||||
if (batch->pos) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.pos = batch->pos + seq.offset;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
llama_pos bi = ids[seq.offset + i];
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
|
||||
}
|
||||
}
|
||||
if (seq.n_seq_id > 0) {
|
||||
if (ubatch.equal_seqs) {
|
||||
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
||||
if (seq.seq_id) {
|
||||
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
||||
@ -3181,9 +3196,10 @@ struct llama_sbatch {
|
||||
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
if (batch->n_seq_id) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]];
|
||||
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
@ -3192,7 +3208,7 @@ struct llama_sbatch {
|
||||
}
|
||||
if (batch->seq_id) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]];
|
||||
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
@ -3201,12 +3217,20 @@ struct llama_sbatch {
|
||||
}
|
||||
}
|
||||
if (batch->logits) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
size_t id = ids[seq.offset + i];
|
||||
int8_t is_output = batch->logits[id];
|
||||
ubatch.output[ubatch.n_tokens + i] = is_output;
|
||||
if (is_output) { out_ids.push_back(id); }
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.output = batch->logits + seq.offset;
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
|
||||
}
|
||||
}
|
||||
} else if (logits_all) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.output[ubatch.n_tokens + i] = 1;
|
||||
@ -3222,18 +3246,18 @@ struct llama_sbatch {
|
||||
}
|
||||
}
|
||||
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
|
||||
ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1;
|
||||
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
|
||||
}
|
||||
ubatch.n_tokens += length;
|
||||
ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits
|
||||
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
|
||||
seq.offset += length;
|
||||
seq.length -= length;
|
||||
n_tokens -= length;
|
||||
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
|
||||
}
|
||||
|
||||
// legacy split, unknown number of sequences of unequal lengths
|
||||
llama_ubatch split_slice(size_t n_ubatch) {
|
||||
// simple split, unknown number of sequences of unequal lengths
|
||||
llama_ubatch split_simple(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
ubatch.equal_seqs = false;
|
||||
@ -3241,7 +3265,6 @@ struct llama_sbatch {
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
|
||||
// TODO: reduce copies
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
}
|
||||
return ubatch;
|
||||
@ -3254,7 +3277,7 @@ struct llama_sbatch {
|
||||
if (!seq.empty()) {
|
||||
size_t length = 0;
|
||||
size_t n_tokens_in_ubatch = 0;
|
||||
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits
|
||||
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
|
||||
// smallest first, because it's easier to split this way;
|
||||
// starting from the end to pop in constant time.
|
||||
for (size_t i = seq.size(); i-- > 0;) {
|
||||
@ -3282,13 +3305,13 @@ struct llama_sbatch {
|
||||
if (!seq.empty()) {
|
||||
llama_sbatch_seq & s = seq[seq.size() - 1];
|
||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits
|
||||
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
}
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) {
|
||||
void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
@ -3302,7 +3325,7 @@ struct llama_sbatch {
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
ids[i] = i;
|
||||
}
|
||||
if (legacy_split) {
|
||||
if (simple_split) {
|
||||
seq.resize(1);
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
s.n_seq_id = 0;
|
||||
@ -13737,7 +13760,7 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch_all, n_embd,
|
||||
/* legacy_split */ rs_self.size == 0,
|
||||
/* simple_split */ rs_self.size == 0,
|
||||
/* logits_all */ n_outputs == n_tokens_all);
|
||||
|
||||
// reserve output buffer
|
||||
@ -13749,7 +13772,7 @@ static int llama_decode_internal(
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
// TODO: deprecate slice splits in favor of equal splits
|
||||
// For now, only use equal splits for recurrent or hybrid model architectures
|
||||
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
|
||||
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
|
||||
const uint32_t n_tokens = u_batch.n_tokens;
|
||||
|
||||
// count the outputs in this u_batch
|
||||
|
Loading…
Reference in New Issue
Block a user