llama : infill sampling handle very long tokens (#9924)

* llama : infill sampling handle very long tokens

ggml-ci

* cont : better indices

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-17 22:32:47 +03:00 committed by GitHub
parent 3752217ed5
commit 99bd4ac28c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 43 deletions

View File

@ -953,12 +953,6 @@ extern "C" {
int32_t lstrip, int32_t lstrip,
bool special); bool special);
// check if token0 is contained as a prefix in token1
LLAMA_API bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1);
/// @details Convert the provided tokens into text (inverse of llama_tokenize()). /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
/// @param text The char pointer must be large enough to hold the resulting text. /// @param text The char pointer must be large enough to hold the resulting text.
/// @return Returns the number of chars/bytes on success, no more than text_len_max. /// @return Returns the number of chars/bytes on success, no more than text_len_max.

View File

@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
struct llama_sampler_infill { struct llama_sampler_infill {
const struct llama_vocab * vocab; const struct llama_vocab * vocab;
std::vector<char> buf0;
std::vector<char> buf1;
}; };
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) { static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
size_t n_combined = 0; GGML_UNUSED(n_combined); size_t n_combined = 0; GGML_UNUSED(n_combined);
// combine tokens with common prefix // combine tokens with common prefix
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
for (size_t j = 0; j < cur_p->size; ++j) { for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
if (cur_p->data[i].logit == -INFINITY) { if (cur_p->data[i0].logit == -INFINITY) {
break; break;
} }
if (i == j || cur_p->data[j].logit == -INFINITY) { if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
continue; continue;
} }
if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) { int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
if (cur_p->data[i].p > cur_p->data[j].p) { if (len0 < 0) {
cur_p->data[i].p += cur_p->data[j].p; ctx->buf0.resize(len0);
cur_p->data[j].logit = -INFINITY; len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
cur_p->data[j].p = 0.0f; assert(len0 > 0);
} else {
cur_p->data[j].p += cur_p->data[i].p;
cur_p->data[i].logit = -INFINITY;
cur_p->data[i].p = 0.0f;
} }
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
if (len1 < 0) {
ctx->buf1.resize(len1);
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
assert(len1 > 0);
}
// token i0 is a prefix of token i1
if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
int dst = i0;
int src = i1;
// merge into the token with higher probability
if (cur_p->data[i1].p > cur_p->data[i0].p) {
std::swap(dst, src);
}
cur_p->data[dst].p += cur_p->data[src].p;
cur_p->data[src].logit = -INFINITY;
cur_p->data[src].p = 0.0f;
n_combined++; n_combined++;
} }
} }
@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
/* .iface = */ &llama_sampler_infill_i, /* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill { /* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab, /* .vocab = */ &vocab,
/* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512),
}, },
}; };
} }

View File

@ -1858,23 +1858,6 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
return 0; return 0;
} }
bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1) {
char text_buf_0[128];
char text_buf_1[128];
const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
if (len0 <= 0 || len1 <= 0) {
return false;
}
return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
}
int32_t llama_detokenize_impl( int32_t llama_detokenize_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
const llama_token * tokens, const llama_token * tokens,

View File

@ -21466,13 +21466,6 @@ int32_t llama_token_to_piece(
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special); return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
} }
bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1) {
return llama_token_is_prefix_impl(model->vocab, token0, token1);
}
int32_t llama_detokenize( int32_t llama_detokenize(
const struct llama_model * model, const struct llama_model * model,
const llama_token * tokens, const llama_token * tokens,