llama : sanitize invalid tokens (#9357)

* common : do not add null tokens during warmup

ggml-ci

* llama : check that the input tokens are valid

ggml-ci

* tests : fix batch size of bert model

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-08 00:33:13 +03:00 committed by GitHub
parent e536426ded
commit faf69d4237
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 4 deletions

View File

@ -2690,10 +2690,15 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
llama_token bos = llama_token_bos(model); llama_token bos = llama_token_bos(model);
llama_token eos = llama_token_eos(model); llama_token eos = llama_token_eos(model);
// some models (e.g. T5) don't have a BOS token // some models (e.g. T5) don't have a BOS token
if (bos != -1) { if (bos != LLAMA_TOKEN_NULL) {
tmp.push_back(bos); tmp.push_back(bos);
} }
tmp.push_back(eos); if (eos != LLAMA_TOKEN_NULL) {
tmp.push_back(eos);
}
if (tmp.empty()) {
tmp.push_back(0);
}
if (llama_model_has_encoder(model)) { if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));

View File

@ -9,8 +9,11 @@ Feature: llama.cpp server
And a model alias bert-bge-small And a model alias bert-bge-small
And 42 as server seed And 42 as server seed
And 2 slots And 2 slots
And 1024 as batch size # the bert-bge-small model has context size of 512
And 1024 as ubatch size # since the generated prompts are as big as the batch size, we need to set the batch size to 512
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
And 512 as batch size
And 512 as ubatch size
And 2048 KV cache size And 2048 KV cache size
And embeddings extraction And embeddings extraction
Then the server is starting Then the server is starting

View File

@ -16066,6 +16066,13 @@ static int llama_decode_internal(
return -1; return -1;
} }
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch_all.token[i] < 0) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
return -1;
}
}
const auto & model = lctx.model; const auto & model = lctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams; const auto & cparams = lctx.cparams;
@ -16358,6 +16365,13 @@ static int llama_encode_internal(
return -1; return -1;
} }
for (uint32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
return -1;
}
}
const auto & model = lctx.model; const auto & model = lctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams; const auto & cparams = lctx.cparams;