From b264eddbb26c695d50d04c37a5b9bb91181bc551 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 20 Aug 2024 23:29:48 -0400 Subject: [PATCH] llama : fix Mamba pooled embeddings with multiple sequences Until the pooled embeddings are refactored to allow splitting across ubatches for causal embeddings, recurrent models can only process a single sequence per ubatch when calculating pooled embeddings. --- src/llama.cpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d4f9ba7e8..bd319e62c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15154,6 +15154,8 @@ static int llama_decode_internal( // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + lctx.embd_seq.clear(); + // count outputs if (batch_all.logits && !embd_pooled) { for (uint32_t i = 0; i < n_tokens_all; ++i) { @@ -15177,8 +15179,19 @@ static int llama_decode_internal( }; while (lctx.sbatch.n_tokens > 0) { - // For now, only use equal splits for recurrent model architectures - llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch); + llama_ubatch ubatch; + if (kv_self.recurrent) { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = lctx.sbatch.split_seq(n_ubatch); + } else { + // recurrent model architectures are easier to implement + // with equal-length sequences + ubatch = lctx.sbatch.split_equal(n_ubatch); + } + } else { + ubatch = lctx.sbatch.split_simple(n_ubatch); + } const uint32_t n_tokens = ubatch.n_tokens; // count the outputs in this u_batch @@ -15316,9 +15329,8 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: { - // extract sequence embeddings + // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = lctx.embd_seq; - embd_seq_out.clear(); for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { const llama_seq_id seq_id = ubatch.seq_id[s][0];