cont : better indices

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-17 16:55:33 +03:00
parent 99c4a39bf1
commit 7899c67f7c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1814,12 +1814,12 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
// combine tokens with common prefix // combine tokens with common prefix
for (size_t i0 = 0; i0 < cur_p->size; ++i0) { for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
for (size_t j0 = 0; j0 < cur_p->size; ++j0) { for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
if (cur_p->data[i0].logit == -INFINITY) { if (cur_p->data[i0].logit == -INFINITY) {
break; break;
} }
if (i0 == j0 || cur_p->data[j0].logit == -INFINITY) { if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
continue; continue;
} }
@ -1830,20 +1830,20 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
assert(len0 > 0); assert(len0 > 0);
} }
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[j0].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
if (len1 < 0) { if (len1 < 0) {
ctx->buf1.resize(len1); ctx->buf1.resize(len1);
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[j0].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
assert(len1 > 0); assert(len1 > 0);
} }
// token i0 is a prefix of token j0 // token i0 is a prefix of token i1
if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) { if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
int dst = i0; int dst = i0;
int src = j0; int src = i1;
// merge into the token with higher probability // merge into the token with higher probability
if (cur_p->data[j0].p > cur_p->data[i0].p) { if (cur_p->data[i1].p > cur_p->data[i0].p) {
std::swap(dst, src); std::swap(dst, src);
} }