From eb589d5e3664b784aef5326aa14dd21889eb1948 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin <git@compilade.net>
Date: Sat, 1 Jun 2024 23:05:13 -0400
Subject: [PATCH] llama : avoid copies for simple batch splits

---
 llama.cpp | 85 +++++++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 52 insertions(+), 33 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 62d66c2bc..ce96d7b55 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3143,19 +3143,29 @@ struct llama_sbatch {
         GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
         // NOTE: loops are separated for cache-friendliness
         if (batch->token) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.token = batch->token + seq.offset;
             }
         } else {
             ubatch.token = nullptr;
         }
         if (batch->embd) {
-            for (size_t i = 0; i < length; ++i) {
-                memcpy(
-                    ubatch.embd + n_embd * (ubatch.n_tokens + i),
-                    batch->embd + n_embd * ids[seq.offset + i],
-                    n_embd * sizeof(float)
-                );
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    memcpy(
+                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
+                        batch->embd + n_embd * ids[seq.offset + i],
+                        n_embd * sizeof(float)
+                    );
+                }
+            } else {
+                // simple split
+                ubatch.embd = batch->embd + (n_embd * seq.offset);
             }
         } else {
             ubatch.embd = nullptr;
@@ -3163,8 +3173,13 @@ struct llama_sbatch {
         // from here on, the else branches are deprecated;
         // they are helpers for smoother batch API transition
         if (batch->pos) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.pos = batch->pos + seq.offset;
             }
         } else {
             for (size_t i = 0; i < length; ++i) {
@@ -3172,7 +3187,7 @@ struct llama_sbatch {
                 ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
             }
         }
-        if (seq.n_seq_id > 0) {
+        if (ubatch.equal_seqs) {
             ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
             if (seq.seq_id) {
                 ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
@@ -3181,31 +3196,36 @@ struct llama_sbatch {
                 ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
             }
         } else {
+            // simple split
             if (batch->n_seq_id) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]];
-                }
+                ubatch.n_seq_id = batch->n_seq_id + seq.offset;
             } else {
                 for (size_t i = 0; i < length; ++i) {
                     ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
                 }
             }
             if (batch->seq_id) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]];
-                }
+                ubatch.seq_id = batch->seq_id + seq.offset;
             } else {
                 for (size_t i = 0; i < length; ++i) {
                     ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
                 }
             }
         }
         if (batch->logits) {
-            for (size_t i = 0; i < length; ++i) {
-                size_t id = ids[seq.offset + i];
-                int8_t is_output = batch->logits[id];
-                ubatch.output[ubatch.n_tokens + i] = is_output;
-                if (is_output) { out_ids.push_back(id); }
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    size_t id = ids[seq.offset + i];
+                    int8_t is_output = batch->logits[id];
+                    ubatch.output[ubatch.n_tokens + i] = is_output;
+                    if (is_output) { out_ids.push_back(id); }
+                }
+            } else {
+                // simple split
+                ubatch.output = batch->logits + seq.offset;
+                for (size_t i = 0; i < length; ++i) {
+                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
+                }
             }
         } else if (logits_all) {
             for (size_t i = 0; i < length; ++i) {
@@ -3222,18 +3242,18 @@ struct llama_sbatch {
             }
         }
         if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-            ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1;
+            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
         }
         ubatch.n_tokens += length;
-        ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits
+        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
         seq.offset += length;
         seq.length -= length;
         n_tokens -= length;
         GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
     }
 
-    // legacy split, unknown number of sequences of unequal lengths
-    llama_ubatch split_slice(size_t n_ubatch) {
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch) {
         n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
         llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
         ubatch.equal_seqs = false;
@@ -3241,7 +3261,6 @@ struct llama_sbatch {
             llama_sbatch_seq & s = seq[0];
             size_t length = s.length < n_ubatch ? s.length : n_ubatch;
             GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
-            // TODO: reduce copies
             add_seq_to_ubatch(ubatch, s, length);
         }
         return ubatch;
@@ -3254,7 +3273,7 @@ struct llama_sbatch {
         if (!seq.empty()) {
             size_t length = 0;
             size_t n_tokens_in_ubatch = 0;
-            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits
+            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
             // smallest first, because it's easier to split this way;
             // starting from the end to pop in constant time.
             for (size_t i = seq.size(); i-- > 0;) {
@@ -3282,13 +3301,13 @@ struct llama_sbatch {
         if (!seq.empty()) {
             llama_sbatch_seq & s = seq[seq.size() - 1];
             size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits
+            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
             add_seq_to_ubatch(ubatch, s, length);
         }
         return ubatch;
     }
 
-    void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) {
+    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
         GGML_ASSERT(batch.n_tokens >= 0);
         this->batch = &batch;
         this->n_embd = n_embd;
@@ -3302,7 +3321,7 @@ struct llama_sbatch {
         for (size_t i = 0; i < n_tokens; ++i) {
             ids[i] = i;
         }
-        if (legacy_split) {
+        if (simple_split) {
             seq.resize(1);
             llama_sbatch_seq & s = seq[0];
             s.n_seq_id = 0;
@@ -13737,7 +13756,7 @@ static int llama_decode_internal(
     }
 
     lctx.sbatch.from_batch(batch_all, n_embd,
-        /* legacy_split */ rs_self.size == 0,
+        /* simple_split */ rs_self.size == 0,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -13749,7 +13768,7 @@ static int llama_decode_internal(
     while (lctx.sbatch.n_tokens > 0) {
         // TODO: deprecate slice splits in favor of equal splits
         // For now, only use equal splits for recurrent or hybrid model architectures
-        llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
+        llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
         const uint32_t n_tokens = u_batch.n_tokens;
 
         // count the outputs in this u_batch
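-- 
Illustrative note (below the signature marker, so it is ignored when the
patch is applied): the reason the simple split can avoid copies is that
from_batch() initializes ids[] as the identity permutation and
split_simple() never reorders it, so the tokens selected for a ubatch form
one contiguous slice of the caller's arrays and a pointer offset can
replace the per-token copy. The equal_seqs path may regroup tokens by
sequence, so it still copies. A minimal, self-contained C++ sketch of the
idea (mini_ubatch and the other names below are invented for illustration,
not llama.cpp API):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // miniature of a ubatch whose token pointer may alias the caller's batch
    struct mini_ubatch {
        const int32_t * token;
        size_t          n_tokens;
    };

    int main() {
        std::vector<int32_t> batch_token = {10, 11, 12, 13, 14, 15};
        const size_t offset = 2, length = 3;

        // equal_seqs-style split: ids[] may permute tokens, so a copy is needed
        std::vector<size_t>  ids = {5, 3, 4}; // tokens regrouped by sequence
        std::vector<int32_t> copied(length);
        for (size_t i = 0; i < length; ++i) {
            copied[i] = batch_token[ids[i]];
        }

        // simple split: ids[] is the identity, so the slice is contiguous and
        // the ubatch can point directly into the original array (no copy)
        mini_ubatch ub = { batch_token.data() + offset, length };

        for (size_t i = 0; i < ub.n_tokens; ++i) {
            std::printf("%d ", ub.token[i]); // prints: 12 13 14
        }
        std::printf("\n");
        return 0;
    }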