llama : fix Mamba pooled embeddings with multiple sequences

Until the pooled embeddings are refactored to allow splitting
across ubatches for causal embeddings,
recurrent models can only process a single sequence per ubatch
when calculating pooled embeddings.
This commit is contained in:
Francis Couture-Harpin 2024-08-20 23:29:48 -04:00
parent 652e9b0d61
commit b264eddbb2

View File

@ -15154,6 +15154,8 @@ static int llama_decode_internal(
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
lctx.embd_seq.clear();
// count outputs
if (batch_all.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
@ -15177,8 +15179,19 @@ static int llama_decode_internal(
};
while (lctx.sbatch.n_tokens > 0) {
// For now, only use equal splits for recurrent model architectures
llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
llama_ubatch ubatch;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}
const uint32_t n_tokens = ubatch.n_tokens;
// count the outputs in this u_batch
@ -15316,9 +15329,8 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];