llama : fix Mamba pooled embeddings with multiple sequences

Until the pooled embeddings are refactored to allow splitting
across ubatches for causal embeddings,
recurrent models can only process a single sequence per ubatch
when calculating pooled embeddings.
This commit is contained in:
Francis Couture-Harpin 2024-08-20 23:29:48 -04:00
parent 652e9b0d61
commit b264eddbb2

View File

@@ -15154,6 +15154,8 @@ static int llama_decode_internal(
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

+    lctx.embd_seq.clear();
+
     // count outputs
     if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
@@ -15177,8 +15179,19 @@ static int llama_decode_internal(
     };

     while (lctx.sbatch.n_tokens > 0) {
-        // For now, only use equal splits for recurrent model architectures
-        llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
+        llama_ubatch ubatch;
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = lctx.sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = lctx.sbatch.split_equal(n_ubatch);
+            }
+        } else {
+            ubatch = lctx.sbatch.split_simple(n_ubatch);
+        }
         const uint32_t n_tokens = ubatch.n_tokens;

         // count the outputs in this u_batch
@@ -15316,9 +15329,8 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_LAST:
                     {
-                        // extract sequence embeddings
+                        // extract sequence embeddings (cleared before processing each batch)
                         auto & embd_seq_out = lctx.embd_seq;
-                        embd_seq_out.clear();

                         for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                             const llama_seq_id seq_id = ubatch.seq_id[s][0];