Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-11-14 14:59:52 +00:00)
llama : add infill sampler (#9896)
ggml-ci
This commit is contained in:
parent 223c25a72f
commit 755a9b2bf0
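For orientation before the diff: a minimal sketch of how the new sampler is meant to be wired into a sampler chain through the public API added below. The chain setup, the top_k/top_p values and the helper name are illustrative only and are not part of this commit; the llama.h calls used here are the existing chain API as I recall it, so treat exact signatures as an assumption.

    #include "llama.h"

    // sketch: build a chain that ends with the new infill sampler
    // (assumes `model` and `ctx` were loaded elsewhere in the usual way)
    static llama_token sample_infill(struct llama_model * model, struct llama_context * ctx) {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));        // per the header comment, run top_k ...
        llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, 1));  // ... and top_p before it
        llama_sampler_chain_add(chain, llama_sampler_init_infill(model));    // new in this commit
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

        const llama_token id = llama_sampler_sample(chain, ctx, -1); // sample for the last decoded position

        llama_sampler_free(chain);
        return id;
    }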
@@ -91,7 +91,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TYPICAL_P   = 5,
     COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
     COMMON_SAMPLER_TYPE_XTC         = 7,
+    COMMON_SAMPLER_TYPE_INFILL      = 8,
 };

 // dimensionality reduction methods, used by cvector-generator
@@ -136,7 +136,7 @@ struct common_sampler_params {
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
         COMMON_SAMPLER_TYPE_XTC,
-        COMMON_SAMPLER_TYPE_TEMPERATURE
+        COMMON_SAMPLER_TYPE_TEMPERATURE,
     };

     std::string grammar; // optional BNF-like grammar to constrain sampling
@@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             case COMMON_SAMPLER_TYPE_TEMPERATURE:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
+            case COMMON_SAMPLER_TYPE_INFILL:
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
         }
@@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
+        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }
@@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
+        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }
@@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "tfs_z",       COMMON_SAMPLER_TYPE_TFS_Z },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
+        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
     };

     // since samplers names are written multiple ways
@@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC }
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
     };

     std::vector<common_sampler_type> samplers;
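With the two tables above extended, the sampler can be requested by its name or by its single-character code through the existing common-layer helpers. A rough usage sketch follows; the exact parameter lists are the ones in the (truncated) hunk headers above, so treat the second argument name as an assumption, and the "top_k"/"top_p" names come from the rest of the table, not from this diff.

    // names as they might arrive from a CLI argument; "infill" now resolves via the table above
    std::vector<std::string> names = { "top_k", "top_p", "infill" };
    auto samplers = common_sampler_types_from_names(names, /*allow_alt_names=*/true);

    // the single-character shorthand works the same way:
    // 't' = temperature, 'x' = xtc, 'i' = infill (all visible in the char table above)
    auto samplers2 = common_sampler_types_from_chars("txi");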
@@ -569,7 +569,8 @@ int main(int argc, char ** argv) {
             if (!params.ctx_shift){
                 LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                 break;
-            } else {
+            }
+
             if (params.n_predict == -2) {
                 LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
                 break;
@@ -593,7 +594,6 @@ int main(int argc, char ** argv) {
                     LOG_DBG("clear session path\n");
                     path_session.clear();
                 }
-            }
         } else {
             // context extension via Self-Extend
             while (n_past >= ga_i + ga_w) {
@@ -953,6 +953,12 @@ extern "C" {
                        int32_t   lstrip,
                           bool   special);

+    // check if token0 is contained as a prefix in token1
+    LLAMA_API bool llama_token_is_prefix(
+            const struct llama_model * model,
+                         llama_token   token0,
+                         llama_token   token1);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1148,6 +1154,28 @@ extern "C" {
                              int32_t   n_logit_bias,
              const llama_logit_bias * logit_bias);

+    // this sampler is meant to be used for fill-in-the-middle infilling
+    // it's supposed to be used after top_k + top_p sampling
+    //
+    // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+    // 2. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    // 3. discard non-EOG tokens with low prob
+    // 4. if no tokens are left -> pick EOT
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
+
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
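Rule 1 above is easier to see with numbers. The implementation further down weighs the EOG mass by a factor of 3, so a hypothetical helper restating that condition (not part of the patch, only a rewrite of the check used in llama_sampler_infill_apply) would look like this:

    // returns true when the sampler should collapse the candidate list to EOG tokens only
    static bool infill_prefers_eog(float p_eog_sum, float p_txt_sum, size_t n_candidates) {
        // e.g. 4 candidates, p_eog_sum = 0.10, p_txt_sum = 0.90:
        //   3 * 0.10 * 4 = 1.20 > 0.90  ->  keep only the EOG tokens
        return 3.0f * p_eog_sum * (float) n_candidates > p_txt_sum;
    }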
@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }

+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
 // utils

 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
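The prefix-combining pass above is quadratic in the number of surviving candidates, which is acceptable because the sampler is meant to run after top_k/top_p has already trimmed the list. A small standalone sketch of the same merge on plain strings, only meant to mirror the "hel"/"hell"/"hello" example from the header comment (the struct and helper name are hypothetical, not from the patch):

    #include <cstring>
    #include <string>
    #include <vector>

    struct cand { std::string piece; float p; };

    // fold each candidate's probability into whichever of a prefix-related pair is more
    // probable, mirroring the merge step of llama_sampler_infill_apply
    static void combine_prefixes(std::vector<cand> & cands) {
        for (size_t i = 0; i < cands.size(); ++i) {
            for (size_t j = 0; j < cands.size(); ++j) {
                if (i == j || cands[i].p == 0.0f || cands[j].p == 0.0f) {
                    continue;
                }
                const auto & a = cands[i].piece;
                const auto & b = cands[j].piece;
                if (a.size() <= b.size() && memcmp(a.data(), b.data(), a.size()) == 0) {
                    // give the combined mass to the more probable of the two
                    if (cands[i].p > cands[j].p) { cands[i].p += cands[j].p; cands[j].p = 0.0f; }
                    else                         { cands[j].p += cands[i].p; cands[i].p = 0.0f; }
                }
            }
        }
    }
    // { {"hel",0.5f}, {"hell",0.2f}, {"hello",0.1f}, {"dummy",0.1f} }  ->  "hel" ends up with 0.8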
@@ -4,8 +4,6 @@

 #include "llama-grammar.h"

-#include <unordered_map>
-
 struct llama_vocab;
 struct llama_grammar;

@@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
                       const char * grammar_str,
                       const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab);
@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }

+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+        llama_token token0,
+        llama_token token1) {
+    char text_buf_0[128];
+    char text_buf_1[128];
+
+    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
+    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
+
+    if (len0 <= 0 || len1 <= 0) {
+        return false;
+    }
+
+    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
+}
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
                const llama_token * tokens,
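The check boils down to a byte-wise prefix comparison of the two detokenized pieces, truncated to the 128-byte stack buffers above. A tiny standalone equivalent on already-detokenized text, just to pin down the semantics (hypothetical helper, not in the patch):

    #include <cstring>
    #include <string>

    // same comparison as llama_token_is_prefix_impl, but on plain strings
    static bool piece_is_prefix(const std::string & p0, const std::string & p1) {
        return p0.size() <= p1.size() && memcmp(p0.data(), p1.data(), p0.size()) == 0;
    }
    // piece_is_prefix("hel", "hello") == true;  piece_is_prefix("hello", "hel") == false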
@@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
         int32_t   lstrip,
            bool   special);

+// check if token0 is contained as a prefix in token1
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+        llama_token token0,
+        llama_token token1);
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
                const llama_token * tokens,
@@ -21500,6 +21500,13 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }

+bool llama_token_is_prefix(
+        const struct llama_model * model,
+        llama_token token0,
+        llama_token token1) {
+    return llama_token_is_prefix_impl(model->vocab, token0, token1);
+}
+
 int32_t llama_detokenize(
         const struct llama_model * model,
                const llama_token * tokens,
@@ -21830,6 +21837,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }

+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
+    return llama_sampler_init_infill_impl(model->vocab);
+}
+
 //
 // model split
 //