mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 18:34:36 +00:00
llama : add infill sampler (#9896)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-musa.Dockerfile platforms:linux/amd64 tag:full-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-musa.Dockerfile platforms:linux/amd64 tag:light-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-musa.Dockerfile platforms:linux/amd64 tag:server-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-musa.Dockerfile platforms:linux/amd64 tag:full-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-musa.Dockerfile platforms:linux/amd64 tag:light-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-musa.Dockerfile platforms:linux/amd64 tag:server-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
ggml-ci
This commit is contained in:
parent
223c25a72f
commit
755a9b2bf0
@ -91,7 +91,7 @@ enum common_sampler_type {
|
||||
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
|
||||
COMMON_SAMPLER_TYPE_XTC = 7,
|
||||
|
||||
COMMON_SAMPLER_TYPE_INFILL = 8,
|
||||
};
|
||||
|
||||
// dimensionality reduction methods, used by cvector-generator
|
||||
@ -136,7 +136,7 @@ struct common_sampler_params {
|
||||
COMMON_SAMPLER_TYPE_TOP_P,
|
||||
COMMON_SAMPLER_TYPE_MIN_P,
|
||||
COMMON_SAMPLER_TYPE_XTC,
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
|
@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||
default : return '?';
|
||||
}
|
||||
}
|
||||
@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||
default : return "";
|
||||
}
|
||||
}
|
||||
@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
|
||||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||
};
|
||||
|
||||
// since samplers names are written multiple ways
|
||||
@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
||||
};
|
||||
|
||||
std::vector<common_sampler_type> samplers;
|
||||
|
@ -569,30 +569,30 @@ int main(int argc, char ** argv) {
|
||||
if (!params.ctx_shift){
|
||||
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
|
||||
break;
|
||||
} else {
|
||||
if (params.n_predict == -2) {
|
||||
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
||||
break;
|
||||
}
|
||||
|
||||
const int n_left = n_past - params.n_keep;
|
||||
const int n_discard = n_left/2;
|
||||
|
||||
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
LOG_DBG("after swap: n_past = %d\n", n_past);
|
||||
|
||||
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
LOG_DBG("clear session path\n");
|
||||
path_session.clear();
|
||||
}
|
||||
|
||||
if (params.n_predict == -2) {
|
||||
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
||||
break;
|
||||
}
|
||||
|
||||
const int n_left = n_past - params.n_keep;
|
||||
const int n_discard = n_left/2;
|
||||
|
||||
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
LOG_DBG("after swap: n_past = %d\n", n_past);
|
||||
|
||||
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
LOG_DBG("clear session path\n");
|
||||
path_session.clear();
|
||||
}
|
||||
} else {
|
||||
// context extension via Self-Extend
|
||||
|
@ -953,6 +953,12 @@ extern "C" {
|
||||
int32_t lstrip,
|
||||
bool special);
|
||||
|
||||
// check if token0 is contained as a prefix in token1
|
||||
LLAMA_API bool llama_token_is_prefix(
|
||||
const struct llama_model * model,
|
||||
llama_token token0,
|
||||
llama_token token1);
|
||||
|
||||
/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
|
||||
/// @param text The char pointer must be large enough to hold the resulting text.
|
||||
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
|
||||
@ -1148,6 +1154,28 @@ extern "C" {
|
||||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias);
|
||||
|
||||
// this sampler is meant to be used for fill-in-the-middle infilling
|
||||
// it's supposed to be used after top_k + top_p sampling
|
||||
//
|
||||
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
|
||||
// 2. combine probs of tokens that have the same prefix
|
||||
//
|
||||
// example:
|
||||
//
|
||||
// - before:
|
||||
// "hel": 0.5
|
||||
// "hell": 0.2
|
||||
// "hello": 0.1
|
||||
// "dummy": 0.1
|
||||
//
|
||||
// - after:
|
||||
// "hel": 0.8
|
||||
// "dummy": 0.1
|
||||
//
|
||||
// 3. discard non-EOG tokens with low prob
|
||||
// 4. if no tokens are left -> pick EOT
|
||||
//
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
|
||||
|
||||
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
||||
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
||||
|
@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
||||
};
|
||||
}
|
||||
|
||||
// infill
|
||||
|
||||
//#define GGML_DEBUG_SAMPLER_INFILL
|
||||
|
||||
struct llama_sampler_infill {
|
||||
const struct llama_vocab * vocab;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "infill";
|
||||
}
|
||||
|
||||
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_infill *) smpl->ctx;
|
||||
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
#if defined(GGML_DEBUG_SAMPLER_INFILL)
|
||||
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
|
||||
#else
|
||||
#define LOG_DBG_CUR(...)
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
||||
}
|
||||
|
||||
float p_txt_sum = 0.0f;
|
||||
float p_eog_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
||||
p_eog_sum += cur_p->data[i].p;
|
||||
} else {
|
||||
p_txt_sum += cur_p->data[i].p;
|
||||
}
|
||||
}
|
||||
|
||||
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
|
||||
|
||||
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
|
||||
|
||||
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
|
||||
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
|
||||
|
||||
// keep just the EOG tokens
|
||||
const auto size_org = cur_p->size;
|
||||
|
||||
cur_p->size = 0;
|
||||
|
||||
float p_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
||||
p_sum += cur_p->data[i].p;
|
||||
|
||||
cur_p->data[cur_p->size++] = cur_p->data[i];
|
||||
}
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].p /= p_sum;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
size_t n_combined = 0; GGML_UNUSED(n_combined);
|
||||
|
||||
// combine tokens with common prefix
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
for (size_t j = 0; j < cur_p->size; ++j) {
|
||||
if (cur_p->data[i].logit == -INFINITY) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (i == j || cur_p->data[j].logit == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
|
||||
if (cur_p->data[i].p > cur_p->data[j].p) {
|
||||
cur_p->data[i].p += cur_p->data[j].p;
|
||||
cur_p->data[j].logit = -INFINITY;
|
||||
cur_p->data[j].p = 0.0f;
|
||||
} else {
|
||||
cur_p->data[j].p += cur_p->data[i].p;
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
cur_p->data[i].p = 0.0f;
|
||||
}
|
||||
|
||||
n_combined++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t n_non_eog = 0;
|
||||
|
||||
size_t size_org = cur_p->size;
|
||||
|
||||
float p_sum = 0.0f;
|
||||
float thold = 0.2f;
|
||||
|
||||
cur_p->size = 0;
|
||||
|
||||
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
||||
|
||||
if (cur_p->data[i].p < thold && !is_eog) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!is_eog) {
|
||||
++n_non_eog;
|
||||
}
|
||||
|
||||
p_sum += cur_p->data[i].p;
|
||||
|
||||
// keep this token
|
||||
cur_p->data[cur_p->size++] = cur_p->data[i];
|
||||
}
|
||||
|
||||
LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
|
||||
|
||||
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
|
||||
if (n_non_eog == 0) {
|
||||
cur_p->size = 1;
|
||||
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
|
||||
cur_p->data[0].logit = 1.0f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].p /= p_sum;
|
||||
|
||||
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
||||
}
|
||||
|
||||
size_org = cur_p->size;
|
||||
p_sum = 0.0f;
|
||||
thold = 1.0/(n_non_eog + 1);
|
||||
|
||||
cur_p->size = 0;
|
||||
|
||||
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
||||
|
||||
if (cur_p->data[i].p < thold && !is_eog) {
|
||||
continue;
|
||||
}
|
||||
|
||||
p_sum += cur_p->data[i].p;
|
||||
|
||||
cur_p->data[cur_p->size++] = cur_p->data[i];
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].p /= p_sum;
|
||||
|
||||
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
||||
}
|
||||
|
||||
#undef LOG_DBG_CUR
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
|
||||
return llama_sampler_init_infill_impl(*ctx->vocab);
|
||||
}
|
||||
|
||||
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_infill *) smpl->ctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_infill_i = {
|
||||
/* .name = */ llama_sampler_infill_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_infill_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_infill_clone,
|
||||
/* .free = */ llama_sampler_infill_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill_impl(
|
||||
const struct llama_vocab & vocab) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_infill_i,
|
||||
/* .ctx = */ new llama_sampler_infill {
|
||||
/* .vocab = */ &vocab,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// utils
|
||||
|
||||
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
||||
|
@ -4,8 +4,6 @@
|
||||
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
struct llama_vocab;
|
||||
struct llama_grammar;
|
||||
|
||||
@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill_impl(
|
||||
const struct llama_vocab & vocab);
|
||||
|
@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool llama_token_is_prefix_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
llama_token token0,
|
||||
llama_token token1) {
|
||||
char text_buf_0[128];
|
||||
char text_buf_1[128];
|
||||
|
||||
const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
|
||||
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
|
||||
|
||||
if (len0 <= 0 || len1 <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
|
||||
}
|
||||
|
||||
int32_t llama_detokenize_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const llama_token * tokens,
|
||||
|
@ -48,7 +48,7 @@ struct llama_vocab {
|
||||
id special_cls_id = LLAMA_TOKEN_NULL;
|
||||
id special_mask_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
id linefeed_id = 13;
|
||||
id linefeed_id = 13;
|
||||
|
||||
// fim tokens
|
||||
id special_fim_pre_id = LLAMA_TOKEN_NULL;
|
||||
@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
|
||||
int32_t lstrip,
|
||||
bool special);
|
||||
|
||||
// check if token0 is contained as a prefix in token1
|
||||
bool llama_token_is_prefix_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
llama_token token0,
|
||||
llama_token token1);
|
||||
|
||||
int32_t llama_detokenize_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const llama_token * tokens,
|
||||
|
@ -21500,6 +21500,13 @@ int32_t llama_token_to_piece(
|
||||
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
|
||||
}
|
||||
|
||||
bool llama_token_is_prefix(
|
||||
const struct llama_model * model,
|
||||
llama_token token0,
|
||||
llama_token token1) {
|
||||
return llama_token_is_prefix_impl(model->vocab, token0, token1);
|
||||
}
|
||||
|
||||
int32_t llama_detokenize(
|
||||
const struct llama_model * model,
|
||||
const llama_token * tokens,
|
||||
@ -21830,6 +21837,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
|
||||
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
|
||||
return llama_sampler_init_infill_impl(model->vocab);
|
||||
}
|
||||
|
||||
//
|
||||
// model split
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user