llama : add infill sampler (#9896)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-musa.Dockerfile platforms:linux/amd64 tag:full-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-musa.Dockerfile platforms:linux/amd64 tag:light-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-musa.Dockerfile platforms:linux/amd64 tag:server-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-15 16:35:33 +03:00 committed by GitHub
parent 223c25a72f
commit 755a9b2bf0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 300 additions and 29 deletions

View File

@ -91,7 +91,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_TYPICAL_P = 5, COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6, COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7, COMMON_SAMPLER_TYPE_XTC = 7,
COMMON_SAMPLER_TYPE_INFILL = 8,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
@ -136,7 +136,7 @@ struct common_sampler_params {
COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC, COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE COMMON_SAMPLER_TYPE_TEMPERATURE,
}; };
std::string grammar; // optional BNF-like grammar to constrain sampling std::string grammar; // optional BNF-like grammar to constrain sampling

View File

@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
break;
default: default:
GGML_ASSERT(false && "unknown sampler type"); GGML_ASSERT(false && "unknown sampler type");
} }
@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
default : return '?'; default : return '?';
} }
} }
@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
default : return ""; default : return "";
} }
} }
@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z }, { "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC }, { "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
}; };
// since samplers names are written multiple ways // since samplers names are written multiple ways
@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC } { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
}; };
std::vector<common_sampler_type> samplers; std::vector<common_sampler_type> samplers;

View File

@ -569,7 +569,8 @@ int main(int argc, char ** argv) {
if (!params.ctx_shift){ if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__); LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break; break;
} else { }
if (params.n_predict == -2) { if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break; break;
@ -593,7 +594,6 @@ int main(int argc, char ** argv) {
LOG_DBG("clear session path\n"); LOG_DBG("clear session path\n");
path_session.clear(); path_session.clear();
} }
}
} else { } else {
// context extension via Self-Extend // context extension via Self-Extend
while (n_past >= ga_i + ga_w) { while (n_past >= ga_i + ga_w) {

View File

@ -953,6 +953,12 @@ extern "C" {
int32_t lstrip, int32_t lstrip,
bool special); bool special);
// check if token0 is contained as a prefix in token1
LLAMA_API bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1);
/// @details Convert the provided tokens into text (inverse of llama_tokenize()). /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
/// @param text The char pointer must be large enough to hold the resulting text. /// @param text The char pointer must be large enough to hold the resulting text.
/// @return Returns the number of chars/bytes on success, no more than text_len_max. /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@ -1148,6 +1154,28 @@ extern "C" {
int32_t n_logit_bias, int32_t n_logit_bias,
const llama_logit_bias * logit_bias); const llama_logit_bias * logit_bias);
// this sampler is meant to be used for fill-in-the-middle infilling
// it's supposed to be used after top_k + top_p sampling
//
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
// 2. combine probs of tokens that have the same prefix
//
// example:
//
// - before:
// "hel": 0.5
// "hell": 0.2
// "hello": 0.1
// "dummy": 0.1
//
// - after:
// "hel": 0.8
// "dummy": 0.1
//
// 3. discard non-EOG tokens with low prob
// 4. if no tokens are left -> pick EOT
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);

View File

@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
}; };
} }
// infill
//#define GGML_DEBUG_SAMPLER_INFILL
struct llama_sampler_infill {
const struct llama_vocab * vocab;
};
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
return "infill";
}
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_infill *) smpl->ctx;
llama_sampler_softmax_impl(cur_p);
#if defined(GGML_DEBUG_SAMPLER_INFILL)
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
#else
#define LOG_DBG_CUR(...)
#endif
for (size_t i = 0; i < cur_p->size; ++i) {
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
float p_txt_sum = 0.0f;
float p_eog_sum = 0.0f;
for (size_t i = 0; i < cur_p->size; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p;
} else {
p_txt_sum += cur_p->data[i].p;
}
}
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
// keep just the EOG tokens
const auto size_org = cur_p->size;
cur_p->size = 0;
float p_sum = 0.0f;
for (size_t i = 0; i < size_org; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_sum += cur_p->data[i].p;
cur_p->data[cur_p->size++] = cur_p->data[i];
}
}
// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
}
return;
}
size_t n_combined = 0; GGML_UNUSED(n_combined);
// combine tokens with common prefix
for (size_t i = 0; i < cur_p->size; ++i) {
for (size_t j = 0; j < cur_p->size; ++j) {
if (cur_p->data[i].logit == -INFINITY) {
break;
}
if (i == j || cur_p->data[j].logit == -INFINITY) {
continue;
}
if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
if (cur_p->data[i].p > cur_p->data[j].p) {
cur_p->data[i].p += cur_p->data[j].p;
cur_p->data[j].logit = -INFINITY;
cur_p->data[j].p = 0.0f;
} else {
cur_p->data[j].p += cur_p->data[i].p;
cur_p->data[i].logit = -INFINITY;
cur_p->data[i].p = 0.0f;
}
n_combined++;
}
}
}
size_t n_non_eog = 0;
size_t size_org = cur_p->size;
float p_sum = 0.0f;
float thold = 0.2f;
cur_p->size = 0;
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) {
continue;
}
if (!is_eog) {
++n_non_eog;
}
p_sum += cur_p->data[i].p;
// keep this token
cur_p->data[cur_p->size++] = cur_p->data[i];
}
LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
if (n_non_eog == 0) {
cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f;
return;
}
// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
size_org = cur_p->size;
p_sum = 0.0f;
thold = 1.0/(n_non_eog + 1);
cur_p->size = 0;
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) {
continue;
}
p_sum += cur_p->data[i].p;
cur_p->data[cur_p->size++] = cur_p->data[i];
}
// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
#undef LOG_DBG_CUR
}
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill_impl(*ctx->vocab);
}
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
delete (llama_sampler_infill *) smpl->ctx;
}
static struct llama_sampler_i llama_sampler_infill_i = {
/* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free,
};
struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab) {
return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab,
},
};
}
// utils // utils
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {

View File

@ -4,8 +4,6 @@
#include "llama-grammar.h" #include "llama-grammar.h"
#include <unordered_map>
struct llama_vocab; struct llama_vocab;
struct llama_grammar; struct llama_grammar;
@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
const char * grammar_str, const char * grammar_str,
const char * grammar_root); const char * grammar_root);
struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab);

View File

@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
return 0; return 0;
} }
bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1) {
char text_buf_0[128];
char text_buf_1[128];
const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
if (len0 <= 0 || len1 <= 0) {
return false;
}
return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
}
int32_t llama_detokenize_impl( int32_t llama_detokenize_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
const llama_token * tokens, const llama_token * tokens,

View File

@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
int32_t lstrip, int32_t lstrip,
bool special); bool special);
// check if token0 is contained as a prefix in token1
bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1);
int32_t llama_detokenize_impl( int32_t llama_detokenize_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
const llama_token * tokens, const llama_token * tokens,

View File

@ -21500,6 +21500,13 @@ int32_t llama_token_to_piece(
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special); return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
} }
bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1) {
return llama_token_is_prefix_impl(model->vocab, token0, token1);
}
int32_t llama_detokenize( int32_t llama_detokenize(
const struct llama_model * model, const struct llama_model * model,
const llama_token * tokens, const llama_token * tokens,
@ -21830,6 +21837,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
} }
struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
return llama_sampler_init_infill_impl(model->vocab);
}
// //
// model split // model split
// //