mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-30 21:34:36 +00:00
llama : add pooling switch
This commit is contained in:
parent
9bbeb0f110
commit
e66da356a4
43
llama.cpp
43
llama.cpp
@ -8113,7 +8113,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||||||
|
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||||
const llama_pos pos = batch.pos[i];
|
const llama_pos pos = batch.pos[i];
|
||||||
if (pos == 0) {
|
if (pos == 0) {
|
||||||
data[seq_id] = i;
|
data[seq_id] = i;
|
||||||
}
|
}
|
||||||
@ -8379,10 +8379,17 @@ static int llama_decode_internal(
|
|||||||
if (batch.logits[i] == 0) {
|
if (batch.logits[i] == 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
switch (hparams.pooling_type) {
|
||||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
} else {
|
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
|
||||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
break;
|
||||||
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
|
case LLAMA_POOLING_TYPE_NONE:
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false && "unknown pooling type");
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -8680,19 +8687,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
|||||||
GGML_ASSERT(llama_is_byte_token(vocab, id));
|
GGML_ASSERT(llama_is_byte_token(vocab, id));
|
||||||
const auto& token_data = vocab.id_to_token.at(id);
|
const auto& token_data = vocab.id_to_token.at(id);
|
||||||
switch (llama_vocab_get_type(vocab)) {
|
switch (llama_vocab_get_type(vocab)) {
|
||||||
case LLAMA_VOCAB_TYPE_SPM: {
|
case LLAMA_VOCAB_TYPE_SPM: {
|
||||||
auto buf = token_data.text.substr(3, 2);
|
auto buf = token_data.text.substr(3, 2);
|
||||||
return strtol(buf.c_str(), NULL, 16);
|
return strtol(buf.c_str(), NULL, 16);
|
||||||
}
|
}
|
||||||
case LLAMA_VOCAB_TYPE_BPE: {
|
case LLAMA_VOCAB_TYPE_BPE: {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
return unicode_to_bytes_bpe(token_data.text);
|
return unicode_to_bytes_bpe(token_data.text);
|
||||||
}
|
}
|
||||||
case LLAMA_VOCAB_TYPE_WPM: {
|
case LLAMA_VOCAB_TYPE_WPM: {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user