llama : add pooling switch

This commit is contained in:
Georgi Gerganov 2024-03-04 14:06:33 +02:00
parent 9bbeb0f110
commit e66da356a4
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -8379,10 +8379,17 @@ static int llama_decode_internal(
if (batch.logits[i] == 0) { if (batch.logits[i] == 0) {
continue; continue;
} }
if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { switch (hparams.pooling_type) {
case LLAMA_POOLING_TYPE_CLS:
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
} else { break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_NONE:
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
break;
default:
GGML_ASSERT(false && "unknown pooling type");
break;
} }
} }
} }