From a803333a4e6fc534c93afe90d741bc2388bdec87 Mon Sep 17 00:00:00 2001 From: Alexey Parfenov Date: Sun, 11 Feb 2024 13:43:31 +0000 Subject: [PATCH] common : use enums for sampler types (#5418) * common: use enums for sampler types * Apply suggestions from code review Co-authored-by: Georgi Gerganov * minor : spaces --------- Co-authored-by: Georgi Gerganov --- common/common.cpp | 117 +++++++++++++++++++++++++++++++------------- common/common.h | 7 ++- common/sampling.cpp | 31 +++++------- common/sampling.h | 20 +++++++- 4 files changed, 120 insertions(+), 55 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9a489a553..f64da2cb6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -340,13 +340,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - sparams.samplers_sequence = parse_samplers_input(argv[i]); + const auto sampler_names = string_split(argv[i], ';'); + sparams.samplers_sequence = sampler_types_from_names(sampler_names); } else if (arg == "--sampling-seq") { if (++i >= argc) { invalid_param = true; break; } - sparams.samplers_sequence = argv[i]; + sparams.samplers_sequence = sampler_types_from_chars(argv[i]); } else if (arg == "--top-p") { if (++i >= argc) { invalid_param = true; @@ -906,6 +907,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const llama_sampling_params & sparams = params.sparams; + std::string sampler_type_chars; + std::string sampler_type_names; + for (const auto sampler_type : sparams.samplers_sequence) { + sampler_type_chars += static_cast(sampler_type); + sampler_type_names += sampler_type_to_name_string(sampler_type) + ";"; + } + sampler_type_names.pop_back(); + printf("\n"); printf("usage: %s [options]\n", argv[0]); printf("\n"); @@ -947,8 +956,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n"); - printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str()); + printf(" --samplers samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str()); + printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str()); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); @@ -1097,45 +1106,85 @@ std::string gpt_random_prompt(std::mt19937 & rng) { } // -// String parsing +// String utils // -std::string parse_samplers_input(std::string input) { - std::string output = ""; +std::vector string_split(std::string input, char separator) { + std::vector parts; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(0, separator_pos); + parts.emplace_back(part); + input = input.substr(separator_pos + 1); + separator_pos = input.find(separator); + } + parts.emplace_back(input); + return parts; +} + +std::vector sampler_types_from_names(const std::vector & names) { // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map samplers_symbols { - {"top_k", 'k'}, - {"top-k", 'k'}, - {"top_p", 'p'}, - {"top-p", 'p'}, - {"nucleus", 'p'}, - {"typical_p", 'y'}, - {"typical-p", 'y'}, - {"typical", 'y'}, - {"min_p", 'm'}, - {"min-p", 'm'}, - {"tfs_z", 'f'}, - {"tfs-z", 'f'}, - {"tfs", 'f'}, - {"temp", 't'}, - {"temperature",'t'} + std::unordered_map sampler_name_map { + {"top_k", llama_sampler_type::TOP_K}, + {"top-k", llama_sampler_type::TOP_K}, + {"top_p", llama_sampler_type::TOP_P}, + {"top-p", llama_sampler_type::TOP_P}, + {"nucleus", llama_sampler_type::TOP_P}, + {"typical_p", llama_sampler_type::TYPICAL_P}, + {"typical-p", llama_sampler_type::TYPICAL_P}, + {"typical", llama_sampler_type::TYPICAL_P}, + {"min_p", llama_sampler_type::MIN_P}, + {"min-p", llama_sampler_type::MIN_P}, + {"tfs_z", llama_sampler_type::TFS_Z}, + {"tfs-z", llama_sampler_type::TFS_Z}, + {"tfs", llama_sampler_type::TFS_Z}, + {"temp", llama_sampler_type::TEMP}, + {"temperature", llama_sampler_type::TEMP} }; - // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p" - size_t separator = input.find(';'); - while (separator != input.npos) { - std::string name = input.substr(0,separator); - input = input.substr(separator+1); - separator = input.find(';'); - if (samplers_symbols.find(name) != samplers_symbols.end()) { - output += samplers_symbols[name]; + std::vector sampler_types; + sampler_types.reserve(names.size()); + for (const auto& name : names) { + const auto sampler_item = sampler_name_map.find(name); + if (sampler_item != sampler_name_map.end()) { + sampler_types.push_back(sampler_item->second); } } - if (samplers_symbols.find(input) != samplers_symbols.end()) { - output += samplers_symbols[input]; + return sampler_types; +} + +std::vector sampler_types_from_chars(const std::string & names_string) { + std::unordered_map sampler_name_map { + {'k', llama_sampler_type::TOP_K}, + {'p', llama_sampler_type::TOP_P}, + {'y', llama_sampler_type::TYPICAL_P}, + {'m', llama_sampler_type::MIN_P}, + {'f', llama_sampler_type::TFS_Z}, + {'t', llama_sampler_type::TEMP} + }; + + std::vector sampler_types; + sampler_types.reserve(names_string.size()); + for (const auto & c : names_string) { + const auto sampler_item = sampler_name_map.find(c); + if (sampler_item != sampler_name_map.end()) { + sampler_types.push_back(sampler_item->second); + } + } + return sampler_types; +} + +std::string sampler_type_to_name_string(llama_sampler_type sampler_type) { + switch (sampler_type) { + case llama_sampler_type::TOP_K: return "top_k"; + case llama_sampler_type::TFS_Z: return "tfs_z"; + case llama_sampler_type::TYPICAL_P: return "typical_p"; + case llama_sampler_type::TOP_P: return "top_p"; + case llama_sampler_type::MIN_P: return "min_p"; + case llama_sampler_type::TEMP: return "temp"; + default : return ""; } - return output; } // diff --git a/common/common.h b/common/common.h index 62de25d6a..9bdd45cf9 100644 --- a/common/common.h +++ b/common/common.h @@ -162,10 +162,13 @@ std::string gpt_random_prompt(std::mt19937 & rng); void process_escapes(std::string& input); // -// String parsing +// String utils // -std::string parse_samplers_input(std::string input); +std::vector sampler_types_from_names(const std::vector & names); +std::vector sampler_types_from_chars(const std::string & names_string); +std::vector string_split(std::string input, char separator); +std::string sampler_type_to_name_string(llama_sampler_type sampler_type); // // Model utils diff --git a/common/sampling.cpp b/common/sampling.cpp index 82cbdecea..a001750da 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -103,15 +103,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) { std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string result = "CFG -> Penalties "; if (params.mirostat == 0) { - for (auto s : params.samplers_sequence) { - switch (s) { - case 'k': result += "-> top_k "; break; - case 'f': result += "-> tfs_z "; break; - case 'y': result += "-> typical_p "; break; - case 'p': result += "-> top_p "; break; - case 'm': result += "-> min_p "; break; - case 't': result += "-> temp "; break; - default : break; + for (auto sampler_type : params.samplers_sequence) { + const auto sampler_type_name = sampler_type_to_name_string(sampler_type); + if (!sampler_type_name.empty()) { + result += "-> " + sampler_type_name + " "; } } } else { @@ -135,16 +130,16 @@ static void sampler_queue( const float min_p = params.min_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; - const std::string & samplers_sequence = params.samplers_sequence; + const std::vector & samplers_sequence = params.samplers_sequence; - for (auto s : samplers_sequence) { - switch (s){ - case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; - case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; - case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; - case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; - case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; - case 't': + for (auto sampler_type : samplers_sequence) { + switch (sampler_type) { + case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; + case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; + case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; + case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; + case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case llama_sampler_type::TEMP: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range); diff --git a/common/sampling.h b/common/sampling.h index 88899c094..2bd6a75d2 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -8,6 +8,16 @@ #include #include +// sampler types +enum class llama_sampler_type : char { + TOP_K = 'k', + TOP_P = 'p', + MIN_P = 'm', + TFS_Z = 'f', + TYPICAL_P = 'y', + TEMP = 't' +}; + // sampling parameters typedef struct llama_sampling_params { int32_t n_prev = 64; // number of previous tokens to remember @@ -28,7 +38,15 @@ typedef struct llama_sampling_params { float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool penalize_nl = true; // consider newlines as a repeatable token - std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp + + std::vector samplers_sequence = { + llama_sampler_type::TOP_K, + llama_sampler_type::TFS_Z, + llama_sampler_type::TYPICAL_P, + llama_sampler_type::TOP_P, + llama_sampler_type::MIN_P, + llama_sampler_type::TEMP + }; std::string grammar; // optional BNF-like grammar to constrain sampling