Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-29 04:44:34 +00:00)

Commit e66da356a4 — "llama : add pooling switch" (parent: 9bbeb0f110)

Changed file: llama.cpp (+11, -11)
@@ -8379,10 +8379,17 @@ static int llama_decode_internal(
                     if (batch.logits[i] == 0) {
                         continue;
                     }
-                    if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
-                        ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
-                    } else {
-                        ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
-                    }
+                    switch (hparams.pooling_type) {
+                        case LLAMA_POOLING_TYPE_CLS:
+                            ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
+                            break;
+                        case LLAMA_POOLING_TYPE_MEAN:
+                        case LLAMA_POOLING_TYPE_NONE:
+                            ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+                            break;
+                        default:
+                            GGML_ASSERT(false && "unknown pooling type");
+                            break;
+                    }
                 }
             }
|
Loading…
Reference in New Issue
Block a user