diff --git a/src/llama.cpp b/src/llama.cpp index ce59d006e..4fc3359b9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2898,7 +2898,12 @@ struct llama_sbatch { } } } - if (batch->logits) { + if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else if (batch->logits) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; @@ -2913,11 +2918,6 @@ struct llama_sbatch { if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } } } - } else if (logits_all) { - for (size_t i = 0; i < length; ++i) { - ubatch.output[ubatch.n_tokens + i] = 1; - out_ids.push_back(ids[seq.offset + i]); - } } else { // only get last output for (size_t i = 0; i < length; ++i) { @@ -15088,7 +15088,7 @@ static int llama_decode_internal( }; while (lctx.sbatch.n_tokens > 0) { - // For now, only use equal splits for recurrent or hybrid model architectures + // For now, only use equal splits for recurrent model architectures llama_ubatch u_batch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch); const uint32_t n_tokens = u_batch.n_tokens;