llama : fix Mamba pooled embeddings with multiple sequences
Until the pooled embeddings are refactored to allow splitting across ubatches for causal embeddings, recurrent models can only process a single sequence per ubatch when calculating pooled embeddings.
parent 652e9b0d61
commit b264eddbb2
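The scenario this commit fixes — pooled embeddings requested for more than one sequence in a single batch on a recurrent (Mamba) model — can be exercised through the public C API roughly as follows. This is a minimal sketch, not part of the commit: the model path and token ids are placeholders, tokenization and error handling are omitted, and the API calls reflect llama.cpp as of this commit's era.

// sketch: two sequences, pooled embeddings, one llama_decode() call
#include "llama.h"
#include <stdbool.h>
#include <stdio.h>

int main(void) {
    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("mamba-130m.Q8_0.gguf", mparams); // placeholder path

    struct llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // request embeddings instead of logits
    cparams.pooling_type = LLAMA_POOLING_TYPE_LAST;  // one pooled vector per sequence
    struct llama_context * ctx = llama_new_context_with_model(model, cparams);

    // two short sequences in one batch, distinguished by seq_id 0 and 1
    // (token ids are illustrative, not from a real vocabulary)
    const llama_token toks[2][3] = { { 1, 15043, 3186 }, { 1, 22172, 3186 } };
    struct llama_batch batch = llama_batch_init(6, 0, 1);
    for (int s = 0; s < 2; ++s) {
        for (int i = 0; i < 3; ++i) {
            const int j = batch.n_tokens++;
            batch.token   [j]    = toks[s][i];
            batch.pos     [j]    = i;
            batch.n_seq_id[j]    = 1;
            batch.seq_id  [j][0] = s;
            batch.logits  [j]    = 0; // ignored when pooling is enabled (see first hunk below)
        }
    }

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode failed\n");
        return 1;
    }

    // before this fix, only one of the two sequences got a valid pooled embedding
    const int n_embd = llama_n_embd(model);
    for (llama_seq_id s = 0; s < 2; ++s) {
        const float * embd = llama_get_embeddings_seq(ctx, s);
        printf("seq %d: first dim = %f (n_embd = %d)\n", s, embd ? embd[0] : 0.0f, n_embd);
    }

    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}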
@@ -15154,6 +15154,8 @@ static int llama_decode_internal(
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    lctx.embd_seq.clear();
+
     // count outputs
     if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
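This hunk moves the clearing of lctx.embd_seq — a per-context map from seq_id to pooled embedding — to once per batch; the last hunk below removes the old per-ubatch clear. The toy program below (a simplified stand-in, not llama.cpp code) shows why the placement matters once each ubatch carries a single sequence: clearing inside the ubatch loop keeps only the last sequence's result.

// toy illustration of per-batch vs per-ubatch clearing of the accumulator
#include <cstdio>
#include <map>
#include <vector>

int main() {
    // pretend each ubatch produced the pooled embedding of exactly one sequence
    const std::vector<std::pair<int, float>> ubatches = { {0, 0.1f}, {1, 0.2f}, {2, 0.3f} };

    std::map<int, float> embd_seq;   // stand-in for lctx.embd_seq
    embd_seq.clear();                // correct: clear once, before the ubatch loop
    for (const auto & ub : ubatches) {
        // buggy variant: uncommenting the next line drops sequences 0 and 1
        // embd_seq.clear();
        embd_seq[ub.first] = ub.second; // extract this ubatch's sequence embedding
    }

    printf("sequences with embeddings: %zu (expected 3)\n", embd_seq.size());
}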
@@ -15177,8 +15179,19 @@
     };
 
     while (lctx.sbatch.n_tokens > 0) {
-        // For now, only use equal splits for recurrent model architectures
-        llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
+        llama_ubatch ubatch;
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = lctx.sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = lctx.sbatch.split_equal(n_ubatch);
+            }
+        } else {
+            ubatch = lctx.sbatch.split_simple(n_ubatch);
+        }
         const uint32_t n_tokens = ubatch.n_tokens;
 
         // count the outputs in this u_batch
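The three sbatch strategies differ in how they group tokens into ubatches: split_simple takes tokens in arrival order (sequences may mix), split_equal builds ubatches whose sequences contribute equal token counts, and split_seq emits tokens of exactly one sequence per ubatch — which is what pooled embeddings need here, per the commit message. The sketch below approximates split_simple and split_seq on a toy token stream; it is my own simplified model, not the llama_sbatch implementation.

// toy approximation of two of the splitting strategies
#include <algorithm>
#include <cstdio>
#include <map>
#include <vector>

struct Tok { int seq_id; int pos; };

// split_simple: up to n tokens in arrival order; sequences may mix in one ubatch
static std::vector<std::vector<Tok>> split_simple_toy(const std::vector<Tok> & toks, size_t n) {
    std::vector<std::vector<Tok>> out;
    for (size_t i = 0; i < toks.size(); i += n) {
        out.emplace_back(toks.begin() + i, toks.begin() + std::min(i + n, toks.size()));
    }
    return out;
}

// split_seq: group by sequence first, then chunk; one sequence per ubatch
static std::vector<std::vector<Tok>> split_seq_toy(const std::vector<Tok> & toks, size_t n) {
    std::map<int, std::vector<Tok>> by_seq;
    for (const Tok & t : toks) by_seq[t.seq_id].push_back(t);
    std::vector<std::vector<Tok>> out;
    for (auto & [seq_id, st] : by_seq) {
        for (size_t i = 0; i < st.size(); i += n) {
            out.emplace_back(st.begin() + i, st.begin() + std::min(i + n, st.size()));
        }
    }
    return out;
}

int main() {
    // two interleaved sequences, as a caller might submit them in one batch
    const std::vector<Tok> batch = { {0,0}, {1,0}, {0,1}, {1,1}, {0,2}, {1,2} };
    for (const auto & ub : split_simple_toy(batch, 4)) printf("simple ubatch: %zu tokens\n", ub.size());
    for (const auto & ub : split_seq_toy   (batch, 4)) printf("seq    ubatch: %zu tokens, all seq %d\n", ub.size(), ub[0].seq_id);
}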
@@ -15316,9 +15329,8 @@
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_LAST:
                     {
-                        // extract sequence embeddings
+                        // extract sequence embeddings (cleared before processing each batch)
                         auto & embd_seq_out = lctx.embd_seq;
-                        embd_seq_out.clear();
 
                         for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                             const llama_seq_id seq_id = ubatch.seq_id[s][0];
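For context, the extraction loop above copies one row of n_embd floats per sequence out of the pooled output tensor into lctx.embd_seq (the real code reads from a ggml tensor, via ggml_backend_tensor_get_async), and llama_get_embeddings_seq() later serves reads from that map by seq_id. A schematic stand-in with plain arrays in place of ggml tensors, not the actual implementation:

// schematic of the per-sequence extraction and the reader-side lookup
#include <cstring>
#include <map>
#include <vector>

int main() {
    const int n_embd = 4;
    const int n_seqs = 2;
    // pretend output of the pooling layer: one row per sequence in the ubatch
    const float pooled[2 * 4] = { /* seq 0 */ 1, 2, 3, 4, /* seq 1 */ 5, 6, 7, 8 };
    const int seq_id[2] = { 0, 1 };

    std::map<int, std::vector<float>> embd_seq_out; // stand-in for lctx.embd_seq
    for (int s = 0; s < n_seqs; ++s) {
        embd_seq_out[seq_id[s]].resize(n_embd);
        std::memcpy(embd_seq_out[seq_id[s]].data(), pooled + s * n_embd, n_embd * sizeof(float));
    }

    // reader side: llama_get_embeddings_seq(ctx, 1) effectively performs this lookup
    const float * e = embd_seq_out.count(1) ? embd_seq_out.at(1).data() : nullptr;
    return (e && e[0] == 5.0f) ? 0 : 1;
}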