llama : logits_all has priority over batch->logits

Otherwise, the server embeddings tests failed.
This was likely an existing problem but was only detected here
because of an additional assertion.
Francis Couture-Harpin 2024-07-17 01:14:26 -04:00
parent 2e4adb47ec
commit 7b7db0bbee
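
A minimal, self-contained sketch of the output-selection priority this commit enforces (the names select_outputs, want_all_outputs, and per_token_flags are hypothetical stand-ins for the sbatch logic, logits_all, and batch->logits; this is not the actual llama.cpp code):

#include <cstdint>
#include <vector>

// Sketch of the output-selection order: the all-outputs flag is checked
// first, then the per-token flags, then the last-token fallback.
// want_all_outputs stands in for logits_all; per_token_flags for batch->logits.
static std::vector<int8_t> select_outputs(size_t n_tokens,
                                          bool want_all_outputs,
                                          const int8_t * per_token_flags) {
    std::vector<int8_t> output(n_tokens, 0);
    if (want_all_outputs) {
        // the all-outputs flag wins: every token produces an output,
        // regardless of the per-token flags
        for (size_t i = 0; i < n_tokens; ++i) {
            output[i] = 1;
        }
    } else if (per_token_flags) {
        // otherwise honor the per-token flags from the batch
        for (size_t i = 0; i < n_tokens; ++i) {
            output[i] = per_token_flags[i];
        }
    } else if (n_tokens > 0) {
        // neither is set: only the last token produces an output
        output[n_tokens - 1] = 1;
    }
    return output;
}

Before this change the per-token flags were consulted first, so a batch carrying explicit flags could suppress outputs even when all logits had been requested, which appears to be what the server embeddings tests tripped over.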


@@ -2898,7 +2898,12 @@ struct llama_sbatch {
                 }
             }
         }
-        if (batch->logits) {
+        if (logits_all) {
+            for (size_t i = 0; i < length; ++i) {
+                ubatch.output[ubatch.n_tokens + i] = 1;
+                out_ids.push_back(ids[seq.offset + i]);
+            }
+        } else if (batch->logits) {
             if (ubatch.equal_seqs) {
                 for (size_t i = 0; i < length; ++i) {
                     size_t id = ids[seq.offset + i];
@@ -2913,11 +2918,6 @@ struct llama_sbatch {
                     if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                 }
             }
-        } else if (logits_all) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.output[ubatch.n_tokens + i] = 1;
-                out_ids.push_back(ids[seq.offset + i]);
-            }
         } else {
             // only get last output
             for (size_t i = 0; i < length; ++i) {
@@ -15088,7 +15088,7 @@ static int llama_decode_internal(
     };
     while (lctx.sbatch.n_tokens > 0) {
-        // For now, only use equal splits for recurrent or hybrid model architectures
+        // For now, only use equal splits for recurrent model architectures
         llama_ubatch u_batch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
         const uint32_t n_tokens = u_batch.n_tokens;