llama : avoid copies for simple batch splits

Francis Couture-Harpin 2024-06-01 23:05:13 -04:00
parent 61200ef29f
commit eb589d5e36


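Before this commit, add_seq_to_ubatch always gathered tokens, embeddings, positions, sequence ids and output flags element by element through the ids permutation. For a simple split, ids is the identity permutation, so the ubatch data is already contiguous in the original batch; the diff below therefore makes the ubatch fields point directly into the batch arrays in that case and only copies for equal_seqs splits. The following is a minimal, hypothetical sketch of that copy-vs-alias decision (token_view and make_token_view are invented names, not the actual llama.cpp structs):

// Hypothetical sketch (names invented): copy-vs-alias decision for a token
// array, analogous to what add_seq_to_ubatch does below for token/embd/pos.
#include <cstdint>
#include <cstdio>
#include <vector>

struct token_view {
    const int32_t * data;   // either aliases the batch or points at a private copy
    size_t          n_tokens;
};

static token_view make_token_view(
        const std::vector<int32_t> & batch_tokens,  // the caller's batch
        const std::vector<size_t>  & ids,           // token order (identity for simple splits)
        size_t offset, size_t length,
        bool equal_seqs,
        std::vector<int32_t> & storage) {           // backing store for the copy path
    if (equal_seqs) {
        // equal-sequence splits may reorder tokens, so gather through ids
        storage.resize(length);
        for (size_t i = 0; i < length; ++i) {
            storage[i] = batch_tokens[ids[offset + i]];
        }
        return { storage.data(), length };
    }
    // simple split: ids is the identity permutation, the range is contiguous,
    // so the view can alias the batch directly (no copy)
    return { batch_tokens.data() + offset, length };
}

int main() {
    std::vector<int32_t> tokens = { 10, 11, 12, 13, 14, 15 };
    std::vector<size_t>  ids    = {  0,  1,  2,  3,  4,  5 };
    std::vector<int32_t> storage;

    token_view v = make_token_view(tokens, ids, /*offset*/ 2, /*length*/ 3,
                                   /*equal_seqs*/ false, storage);
    std::printf("aliased: %d %d %d\n", v.data[0], v.data[1], v.data[2]); // 12 13 14
    return 0;
}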
@@ -3143,19 +3143,29 @@ struct llama_sbatch {
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
// NOTE: loops are separated for cache-friendliness
if (batch->token) {
for (size_t i = 0; i < length; ++i) {
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
}
} else {
// simple split
ubatch.token = batch->token + seq.offset;
}
} else {
ubatch.token = nullptr;
}
if (batch->embd) {
for (size_t i = 0; i < length; ++i) {
memcpy(
ubatch.embd + n_embd * (ubatch.n_tokens + i),
batch->embd + n_embd * ids[seq.offset + i],
n_embd * sizeof(float)
);
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
memcpy(
ubatch.embd + n_embd * (ubatch.n_tokens + i),
batch->embd + n_embd * ids[seq.offset + i],
n_embd * sizeof(float)
);
}
} else {
// simple split
ubatch.embd = batch->embd + seq.offset;
}
} else {
ubatch.embd = nullptr;
@@ -3163,8 +3173,13 @@ struct llama_sbatch {
// from here on, the else branches are deprecated;
// they are helpers for smoother batch API transition
if (batch->pos) {
for (size_t i = 0; i < length; ++i) {
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
}
} else {
// simple split
ubatch.pos = batch->pos + seq.offset;
}
} else {
for (size_t i = 0; i < length; ++i) {
@@ -3172,7 +3187,7 @@ struct llama_sbatch {
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
}
}
if (seq.n_seq_id > 0) {
if (ubatch.equal_seqs) {
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
if (seq.seq_id) {
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
@@ -3181,9 +3196,10 @@ struct llama_sbatch {
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
}
} else {
// simple split
if (batch->n_seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]];
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
}
} else {
for (size_t i = 0; i < length; ++i) {
@@ -3192,7 +3208,7 @@ struct llama_sbatch {
}
if (batch->seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]];
ubatch.seq_id = batch->seq_id + seq.offset;
}
} else {
for (size_t i = 0; i < length; ++i) {
@@ -3201,11 +3217,19 @@ struct llama_sbatch {
}
}
if (batch->logits) {
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
int8_t is_output = batch->logits[id];
ubatch.output[ubatch.n_tokens + i] = is_output;
if (is_output) { out_ids.push_back(id); }
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
int8_t is_output = batch->logits[id];
ubatch.output[ubatch.n_tokens + i] = is_output;
if (is_output) { out_ids.push_back(id); }
}
} else {
// simple split
ubatch.output = batch->logits + seq.offset;
for (size_t i = 0; i < length; ++i) {
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
}
}
} else if (logits_all) {
for (size_t i = 0; i < length; ++i) {
@@ -3222,18 +3246,18 @@ struct llama_sbatch {
}
}
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1;
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
}
ubatch.n_tokens += length;
ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
seq.offset += length;
seq.length -= length;
n_tokens -= length;
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}
// legacy split, unknown number of sequences of unequal lengths
llama_ubatch split_slice(size_t n_ubatch) {
// simple split, unknown number of sequences of unequal lengths
llama_ubatch split_simple(size_t n_ubatch) {
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
ubatch.equal_seqs = false;
@@ -3241,7 +3265,6 @@ struct llama_sbatch {
llama_sbatch_seq & s = seq[0];
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
// TODO: reduce copies
add_seq_to_ubatch(ubatch, s, length);
}
return ubatch;
@@ -3254,7 +3277,7 @@ struct llama_sbatch {
if (!seq.empty()) {
size_t length = 0;
size_t n_tokens_in_ubatch = 0;
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
// smallest first, because it's easier to split this way;
// starting from the end to pop in constant time.
for (size_t i = seq.size(); i-- > 0;) {
@@ -3282,13 +3305,13 @@ struct llama_sbatch {
if (!seq.empty()) {
llama_sbatch_seq & s = seq[seq.size() - 1];
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
add_seq_to_ubatch(ubatch, s, length);
}
return ubatch;
}
void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) {
void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
@@ -3302,7 +3325,7 @@ struct llama_sbatch {
for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i;
}
if (legacy_split) {
if (simple_split) {
seq.resize(1);
llama_sbatch_seq & s = seq[0];
s.n_seq_id = 0;
@@ -13737,7 +13760,7 @@ static int llama_decode_internal(
}
lctx.sbatch.from_batch(batch_all, n_embd,
/* legacy_split */ rs_self.size == 0,
/* simple_split */ rs_self.size == 0,
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer
@@ -13749,7 +13772,7 @@ static int llama_decode_internal(
while (lctx.sbatch.n_tokens > 0) {
// TODO: deprecate slice splits in favor of equal splits
// For now, only use equal splits for recurrent or hybrid model architectures
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
const uint32_t n_tokens = u_batch.n_tokens;
// count the outputs in this u_batch
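The bookkeeping at the end of add_seq_to_ubatch, together with the split_simple/split_equal call site in llama_decode_internal above, maintains the invariant n_tokens == n_seq_tokens * n_seqs, where a simple split is counted as length single-token "virtual" sequences while an equal split adds one real sequence of length tokens. One practical consequence of the diff (an observation, not something stated in the commit message) is that a simple-split ubatch may alias the caller's llama_batch arrays, so it is only valid while that batch stays alive and unmodified. Below is a hypothetical, self-contained illustration of just the accounting (ubatch_counts and add_seq are invented names, not llama.cpp API), mirroring the hunk around the GGML_ASSERT above:

// Hypothetical illustration of the ubatch accounting for simple vs equal splits.
#include <cassert>
#include <cstdio>

struct ubatch_counts {
    size_t n_tokens     = 0;
    size_t n_seq_tokens = 0;
    size_t n_seqs       = 0;
};

static void add_seq(ubatch_counts & u, size_t length, bool equal_seqs) {
    if (u.n_tokens == 0 && u.n_seqs == 0) {
        u.n_seq_tokens = equal_seqs ? length : 1;
    }
    u.n_tokens += length;
    u.n_seqs   += equal_seqs ? 1 : length;  // virtual sequences for simple splits
    assert(u.n_tokens == u.n_seq_tokens * u.n_seqs);
}

int main() {
    ubatch_counts simple;
    add_seq(simple, 512, /*equal_seqs*/ false);   // one contiguous slice of the batch
    std::printf("simple: %zu tokens = %zu seqs x %zu tokens each\n",
                simple.n_tokens, simple.n_seqs, simple.n_seq_tokens);

    ubatch_counts equal;
    add_seq(equal, 128, /*equal_seqs*/ true);     // two sequences of equal length
    add_seq(equal, 128, /*equal_seqs*/ true);
    std::printf("equal:  %zu tokens = %zu seqs x %zu tokens each\n",
                equal.n_tokens, equal.n_seqs, equal.n_seq_tokens);
    return 0;
}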