mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-08 01:31:45 +00:00
llama : fix batch split output count for embeddings
This commit is contained in:
parent
5d3c7b9585
commit
72eea49224
@ -13730,7 +13730,9 @@ static int llama_decode_internal(
|
|||||||
n_outputs = 1;
|
n_outputs = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all);
|
lctx.sbatch.from_batch(batch_all, n_embd,
|
||||||
|
/* legacy_split */ rs_self.size == 0,
|
||||||
|
/* logits_all */ n_outputs == n_tokens_all);
|
||||||
|
|
||||||
// reserve output buffer
|
// reserve output buffer
|
||||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||||
@ -13740,6 +13742,7 @@ static int llama_decode_internal(
|
|||||||
|
|
||||||
while (lctx.sbatch.n_tokens > 0) {
|
while (lctx.sbatch.n_tokens > 0) {
|
||||||
// TODO: deprecate slice splits in favor of equal splits
|
// TODO: deprecate slice splits in favor of equal splits
|
||||||
|
// For now, only use equal splits for recurrent or hybrid model architectures
|
||||||
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
|
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
|
||||||
const uint32_t n_tokens = u_batch.n_tokens;
|
const uint32_t n_tokens = u_batch.n_tokens;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user