From 7e2b5fb1dd180ba9988fe6447ce999edf5bef1ab Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 20 Oct 2023 18:02:50 +0300 Subject: [PATCH] sampling : add llama_sampling_print helper --- common/sampling.cpp | 14 ++++++++++++++ common/sampling.h | 7 +++++-- examples/infill/infill.cpp | 3 +-- examples/main/main.cpp | 3 +-- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 3db2cede8..422292175 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -66,6 +66,20 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds dst->prev = src->prev; } +std::string llama_sampling_print(const llama_sampling_params & params) { + char result[1024]; + + snprintf(result, sizeof(result), + "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" + "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, + params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, + params.mirostat, params.mirostat_eta, params.mirostat_tau); + + return std::string(result); +} + llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/common/sampling.h b/common/sampling.h index ccfeb13af..d8ee5126a 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -30,8 +30,8 @@ typedef struct llama_sampling_params { // Classifier-Free Guidance // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // How strong is guidance + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // how strong is guidance std::unordered_map logit_bias; // logit bias for specific tokens } llama_sampling_params; @@ -70,6 +70,9 @@ void llama_sampling_reset(llama_sampling_context * ctx); // Copy the sampler context void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); +// Print sampling parameters into a string +std::string llama_sampling_print(const llama_sampling_params & params); + // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 00843fef7..8f520e38e 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -358,8 +358,7 @@ int main(int argc, char ** argv) { LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); } } - LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", - sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau); + LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index cc9974b28..d36e2a43b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -415,8 +415,7 @@ int main(int argc, char ** argv) { } } } - LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, frequency_penalty = %f, presence_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", - sparams.penalty_last_n, sparams.penalty_repeat, sparams.penalty_freq, sparams.penalty_present, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau); + LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n");