llama : fix T5 segfault again

author Francis Couture-Harpin 2024-08-20 21:37:43 -04:00
parent 702e1995a1
commit 652e9b0d61


@@ -15482,12 +15482,13 @@ static int llama_encode_internal(
         float * embd_out = lctx.embd_enc.data();
 
         ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
         // remember the sequence ids used during the encoding - needed for cross attention later
         lctx.seq_ids_enc.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
-            for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = batch.seq_id[i][s];
+            for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = ubatch.seq_id[i][s];
                 lctx.seq_ids_enc[i].insert(seq_id);
             }
         }
@@ -15512,8 +15513,10 @@ static int llama_encode_internal(
         auto & embd_seq_out = lctx.embd_seq;
         embd_seq_out.clear();
 
+        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
+            const llama_seq_id seq_id = ubatch.seq_id[i][0];
             if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                 continue;
             }
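
The change itself is narrow: both encoder paths now read per-token sequence ids from ubatch (the split view derived from the input batch) instead of from batch, and assert !ubatch.equal_seqs until equal-sequence splits are handled. The following is a minimal standalone sketch of why that distinction can matter; it does not use the real llama.cpp types. toy_batch, toy_ubatch and make_ubatch are simplified, hypothetical stand-ins, and the null seq_id scenario shown is one plausible way the old batch-indexing code could crash, not a statement of the exact failure the author hit.

// Minimal sketch (hypothetical, simplified stand-ins, not the real llama.cpp
// structs): why reading sequence ids from the original batch instead of the
// split ubatch can dereference a null pointer.
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

using llama_seq_id = int32_t;

// Stand-in for the input batch: the per-token seq_id arrays are optional
// and may be left null by a minimal caller.
struct toy_batch {
    uint32_t        n_tokens = 0;
    int32_t       * n_seq_id = nullptr;
    llama_seq_id ** seq_id   = nullptr;
};

// Stand-in for the split ubatch: the per-token seq_id view is always filled in.
struct toy_ubatch {
    uint32_t                               n_tokens = 0;
    std::vector<int32_t>                   n_seq_id;
    std::vector<std::vector<llama_seq_id>> seq_id;
};

// Stand-in for the batch splitter: copies the caller's seq ids when present,
// otherwise assigns every token to sequence 0 by default.
static toy_ubatch make_ubatch(const toy_batch & batch) {
    toy_ubatch ub;
    ub.n_tokens = batch.n_tokens;
    for (uint32_t i = 0; i < batch.n_tokens; i++) {
        if (batch.seq_id != nullptr) {
            ub.n_seq_id.push_back(batch.n_seq_id[i]);
            ub.seq_id.emplace_back(batch.seq_id[i], batch.seq_id[i] + batch.n_seq_id[i]);
        } else {
            ub.n_seq_id.push_back(1);
            ub.seq_id.push_back({ 0 });
        }
    }
    return ub;
}

int main() {
    toy_batch batch;
    batch.n_tokens = 4; // seq_id intentionally left null

    const toy_ubatch ubatch = make_ubatch(batch);

    // Mirrors the fixed encoder loop: remember the sequence ids per token,
    // reading them from the ubatch, which is always populated.
    std::vector<std::set<llama_seq_id>> seq_ids_enc(ubatch.n_tokens);
    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
        for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
            seq_ids_enc[i].insert(ubatch.seq_id[i][s]);
        }
        // The pre-fix code indexed batch.n_seq_id[i] / batch.seq_id[i][s]
        // here; for a batch like the one above those pointers are null.
    }

    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
        printf("token %u belongs to %zu sequence(s)\n", i, seq_ids_enc[i].size());
    }
    return 0;
}

Running the sketch prints one line per token; switching the inner loop to read from batch instead of ubatch would dereference the null seq_id pointer, which is the kind of crash the commit title refers to.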