mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 09:11:46 +00:00
llama : avoid copies for simple batch splits
This commit is contained in:
parent
61200ef29f
commit
eb589d5e36
81
llama.cpp
81
llama.cpp
@ -3143,19 +3143,29 @@ struct llama_sbatch {
|
|||||||
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
|
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
|
||||||
// NOTE: loops are separated for cache-friendliness
|
// NOTE: loops are separated for cache-friendliness
|
||||||
if (batch->token) {
|
if (batch->token) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
if (ubatch.equal_seqs) {
|
||||||
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
|
for (size_t i = 0; i < length; ++i) {
|
||||||
|
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// simple split
|
||||||
|
ubatch.token = batch->token + seq.offset;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ubatch.token = nullptr;
|
ubatch.token = nullptr;
|
||||||
}
|
}
|
||||||
if (batch->embd) {
|
if (batch->embd) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
if (ubatch.equal_seqs) {
|
||||||
memcpy(
|
for (size_t i = 0; i < length; ++i) {
|
||||||
ubatch.embd + n_embd * (ubatch.n_tokens + i),
|
memcpy(
|
||||||
batch->embd + n_embd * ids[seq.offset + i],
|
ubatch.embd + n_embd * (ubatch.n_tokens + i),
|
||||||
n_embd * sizeof(float)
|
batch->embd + n_embd * ids[seq.offset + i],
|
||||||
);
|
n_embd * sizeof(float)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// simple split
|
||||||
|
ubatch.embd = batch->embd + seq.offset;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ubatch.embd = nullptr;
|
ubatch.embd = nullptr;
|
||||||
@ -3163,8 +3173,13 @@ struct llama_sbatch {
|
|||||||
// from here on, the else branches are deprecated;
|
// from here on, the else branches are deprecated;
|
||||||
// they are helpers for smoother batch API transition
|
// they are helpers for smoother batch API transition
|
||||||
if (batch->pos) {
|
if (batch->pos) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
if (ubatch.equal_seqs) {
|
||||||
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
for (size_t i = 0; i < length; ++i) {
|
||||||
|
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// simple split
|
||||||
|
ubatch.pos = batch->pos + seq.offset;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
@ -3172,7 +3187,7 @@ struct llama_sbatch {
|
|||||||
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
|
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (seq.n_seq_id > 0) {
|
if (ubatch.equal_seqs) {
|
||||||
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
||||||
if (seq.seq_id) {
|
if (seq.seq_id) {
|
||||||
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
||||||
@ -3181,9 +3196,10 @@ struct llama_sbatch {
|
|||||||
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
|
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// simple split
|
||||||
if (batch->n_seq_id) {
|
if (batch->n_seq_id) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]];
|
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
@ -3192,7 +3208,7 @@ struct llama_sbatch {
|
|||||||
}
|
}
|
||||||
if (batch->seq_id) {
|
if (batch->seq_id) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]];
|
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
@ -3201,11 +3217,19 @@ struct llama_sbatch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (batch->logits) {
|
if (batch->logits) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
if (ubatch.equal_seqs) {
|
||||||
size_t id = ids[seq.offset + i];
|
for (size_t i = 0; i < length; ++i) {
|
||||||
int8_t is_output = batch->logits[id];
|
size_t id = ids[seq.offset + i];
|
||||||
ubatch.output[ubatch.n_tokens + i] = is_output;
|
int8_t is_output = batch->logits[id];
|
||||||
if (is_output) { out_ids.push_back(id); }
|
ubatch.output[ubatch.n_tokens + i] = is_output;
|
||||||
|
if (is_output) { out_ids.push_back(id); }
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// simple split
|
||||||
|
ubatch.output = batch->logits + seq.offset;
|
||||||
|
for (size_t i = 0; i < length; ++i) {
|
||||||
|
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if (logits_all) {
|
} else if (logits_all) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
@ -3222,18 +3246,18 @@ struct llama_sbatch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
|
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
|
||||||
ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1;
|
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
|
||||||
}
|
}
|
||||||
ubatch.n_tokens += length;
|
ubatch.n_tokens += length;
|
||||||
ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits
|
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
|
||||||
seq.offset += length;
|
seq.offset += length;
|
||||||
seq.length -= length;
|
seq.length -= length;
|
||||||
n_tokens -= length;
|
n_tokens -= length;
|
||||||
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
|
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
|
||||||
}
|
}
|
||||||
|
|
||||||
// legacy split, unknown number of sequences of unequal lengths
|
// simple split, unknown number of sequences of unequal lengths
|
||||||
llama_ubatch split_slice(size_t n_ubatch) {
|
llama_ubatch split_simple(size_t n_ubatch) {
|
||||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||||
ubatch.equal_seqs = false;
|
ubatch.equal_seqs = false;
|
||||||
@ -3241,7 +3265,6 @@ struct llama_sbatch {
|
|||||||
llama_sbatch_seq & s = seq[0];
|
llama_sbatch_seq & s = seq[0];
|
||||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||||
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
|
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
|
||||||
// TODO: reduce copies
|
|
||||||
add_seq_to_ubatch(ubatch, s, length);
|
add_seq_to_ubatch(ubatch, s, length);
|
||||||
}
|
}
|
||||||
return ubatch;
|
return ubatch;
|
||||||
@ -3254,7 +3277,7 @@ struct llama_sbatch {
|
|||||||
if (!seq.empty()) {
|
if (!seq.empty()) {
|
||||||
size_t length = 0;
|
size_t length = 0;
|
||||||
size_t n_tokens_in_ubatch = 0;
|
size_t n_tokens_in_ubatch = 0;
|
||||||
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits
|
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
|
||||||
// smallest first, because it's easier to split this way;
|
// smallest first, because it's easier to split this way;
|
||||||
// starting from the end to pop in constant time.
|
// starting from the end to pop in constant time.
|
||||||
for (size_t i = seq.size(); i-- > 0;) {
|
for (size_t i = seq.size(); i-- > 0;) {
|
||||||
@ -3282,13 +3305,13 @@ struct llama_sbatch {
|
|||||||
if (!seq.empty()) {
|
if (!seq.empty()) {
|
||||||
llama_sbatch_seq & s = seq[seq.size() - 1];
|
llama_sbatch_seq & s = seq[seq.size() - 1];
|
||||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||||
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits
|
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
|
||||||
add_seq_to_ubatch(ubatch, s, length);
|
add_seq_to_ubatch(ubatch, s, length);
|
||||||
}
|
}
|
||||||
return ubatch;
|
return ubatch;
|
||||||
}
|
}
|
||||||
|
|
||||||
void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) {
|
void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
|
||||||
GGML_ASSERT(batch.n_tokens >= 0);
|
GGML_ASSERT(batch.n_tokens >= 0);
|
||||||
this->batch = &batch;
|
this->batch = &batch;
|
||||||
this->n_embd = n_embd;
|
this->n_embd = n_embd;
|
||||||
@ -3302,7 +3325,7 @@ struct llama_sbatch {
|
|||||||
for (size_t i = 0; i < n_tokens; ++i) {
|
for (size_t i = 0; i < n_tokens; ++i) {
|
||||||
ids[i] = i;
|
ids[i] = i;
|
||||||
}
|
}
|
||||||
if (legacy_split) {
|
if (simple_split) {
|
||||||
seq.resize(1);
|
seq.resize(1);
|
||||||
llama_sbatch_seq & s = seq[0];
|
llama_sbatch_seq & s = seq[0];
|
||||||
s.n_seq_id = 0;
|
s.n_seq_id = 0;
|
||||||
@ -13737,7 +13760,7 @@ static int llama_decode_internal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
lctx.sbatch.from_batch(batch_all, n_embd,
|
lctx.sbatch.from_batch(batch_all, n_embd,
|
||||||
/* legacy_split */ rs_self.size == 0,
|
/* simple_split */ rs_self.size == 0,
|
||||||
/* logits_all */ n_outputs == n_tokens_all);
|
/* logits_all */ n_outputs == n_tokens_all);
|
||||||
|
|
||||||
// reserve output buffer
|
// reserve output buffer
|
||||||
@ -13749,7 +13772,7 @@ static int llama_decode_internal(
|
|||||||
while (lctx.sbatch.n_tokens > 0) {
|
while (lctx.sbatch.n_tokens > 0) {
|
||||||
// TODO: deprecate slice splits in favor of equal splits
|
// TODO: deprecate slice splits in favor of equal splits
|
||||||
// For now, only use equal splits for recurrent or hybrid model architectures
|
// For now, only use equal splits for recurrent or hybrid model architectures
|
||||||
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
|
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
|
||||||
const uint32_t n_tokens = u_batch.n_tokens;
|
const uint32_t n_tokens = u_batch.n_tokens;
|
||||||
|
|
||||||
// count the outputs in this u_batch
|
// count the outputs in this u_batch
|
||||||
|
Loading…
Reference in New Issue
Block a user