diff --git a/Makefile b/Makefile
index 24acb8013..9dc35410a 100644
--- a/Makefile
+++ b/Makefile
@@ -768,7 +768,7 @@ batched-bench: examples/batched-bench/batched-bench.cpp build-info.o ggml.
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-quantize: examples/quantize/quantize.cpp build-info.o ggml.o llama.o $(OBJS)
+quantize: examples/quantize/quantize.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
diff --git a/common/common.cpp b/common/common.cpp
index 9ab8aef7e..d42fa131d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -234,8 +234,54 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     return result;
 }
 
+bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
+    const char * sep = strchr(data, '=');
+    if (sep == nullptr || sep - data >= 128) {
+        fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
+        return false;
+    }
+    llama_model_kv_override kvo;
+    std::strncpy(kvo.key, data, sep - data);
+    kvo.key[sep - data] = 0;
+    sep++;
+    if (strncmp(sep, "int:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+        kvo.val_i64 = std::atol(sep);
+    } else if (strncmp(sep, "float:", 6) == 0) {
+        sep += 6;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+        kvo.val_f64 = std::atof(sep);
+    } else if (strncmp(sep, "bool:", 5) == 0) {
+        sep += 5;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
+        if (std::strcmp(sep, "true") == 0) {
+            kvo.val_bool = true;
+        } else if (std::strcmp(sep, "false") == 0) {
+            kvo.val_bool = false;
+        } else {
+            fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
+            return false;
+        }
+    } else if (strncmp(sep, "str:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+        if (strlen(sep) > 127) {
+            fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
+            return false;
+        }
+        strncpy(kvo.val_str, sep, 127);
+        kvo.val_str[127] = '\0';
+    } else {
+        fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
+        return false;
+    }
+    overrides.emplace_back(std::move(kvo));
+    return true;
+}
+
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
-    llama_sampling_params& sparams = params.sparams;
+    llama_sampling_params & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
         if (++i >= argc) {
@@ -1244,47 +1290,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
-        char* sep = strchr(argv[i], '=');
-        if (sep == nullptr || sep - argv[i] >= 128) {
-            fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
-            invalid_param = true;
-            return true;
-        }
-        struct llama_model_kv_override kvo;
-        std::strncpy(kvo.key, argv[i], sep - argv[i]);
-        kvo.key[sep - argv[i]] = 0;
-        sep++;
-        if (strncmp(sep, "int:", 4) == 0) {
-            sep += 4;
-            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-            kvo.int_value = std::atol(sep);
-        }
-        else if (strncmp(sep, "float:", 6) == 0) {
-            sep += 6;
-            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
-            kvo.float_value = std::atof(sep);
-        }
-        else if (strncmp(sep, "bool:", 5) == 0) {
-            sep += 5;
-            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
-            if (std::strcmp(sep, "true") == 0) {
-                kvo.bool_value = true;
-            }
-            else if (std::strcmp(sep, "false") == 0) {
-                kvo.bool_value = false;
-            }
-            else {
-                fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
-                invalid_param = true;
-                return true;
-            }
-        }
-        else {
+        if (!parse_kv_override(argv[i], params.kv_overrides)) {
             fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
             invalid_param = true;
             return true;
         }
-        params.kv_overrides.push_back(kvo);
         return true;
     }
 #ifndef LOG_DISABLE_LOGS
@@ -1555,7 +1565,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
-    printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("                        types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -ptc N, --print-token-count N\n");
     printf("                        print token count every N tokens (default: %d)\n", params.n_print);
     printf("  --check-tensors       check model tensor data for invalid values\n");
diff --git a/common/common.h b/common/common.h
index 0005f143e..96a28a6ce 100644
--- a/common/common.h
+++ b/common/common.h
@@ -171,6 +171,8 @@ struct gpt_params {
     std::string image = ""; // path to an image file
 };
 
+bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
+
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 98c0e93e4..71e7a727f 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -23,6 +23,7 @@ struct Stats {
 };
 
 struct StatParams {
+    std::string dataset;
     std::string ofile = "imatrix.dat";
     int         n_output_frequency = 10;
     int         verbosity = 1;
@@ -46,7 +47,7 @@
 private:
     std::vector<float> m_src1_data;
     std::vector<int>   m_ids; // the expert ids from ggml_mul_mat_id
     //
-    void save_imatrix(const char * file_name) const;
+    void save_imatrix(const char * file_name, const char * dataset) const;
     void keep_imatrix(int ncall) const;
 };
@@ -199,7 +200,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
 }
 
 void IMatrixCollector::save_imatrix() const {
-    save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str());
+    save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str(), m_params.dataset.c_str());
 }
 
 void IMatrixCollector::keep_imatrix(int ncall) const {
@@ -207,24 +208,33 @@ void IMatrixCollector::keep_imatrix(int ncall) const {
     if (file_name.empty()) file_name = "imatrix.dat";
     file_name += ".at_";
     file_name += std::to_string(ncall);
-    save_imatrix(file_name.c_str());
+    save_imatrix(file_name.c_str(), m_params.dataset.c_str());
 }
 
-void IMatrixCollector::save_imatrix(const char * fname) const {
+void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) const {
     std::ofstream out(fname, std::ios::binary);
     int n_entries = m_stats.size();
-    out.write((const char*)&n_entries, sizeof(n_entries));
-    for (auto& p : m_stats) {
+    out.write((const char *) &n_entries, sizeof(n_entries));
+    for (const auto & p : m_stats) {
         int len = p.first.size();
-        out.write((const char*)&len, sizeof(len));
+        out.write((const char *) &len, sizeof(len));
         out.write(p.first.c_str(), len);
-        out.write((const char*)&p.second.ncall, sizeof(p.second.ncall));
+        out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
         int nval = p.second.values.size();
-        out.write((const char*)&nval, sizeof(nval));
-        if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float));
+        out.write((const char *) &nval, sizeof(nval));
+        if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
     }
+
+    // Write the number of calls the matrix was computed with
+    out.write((const char *) &m_last_call, sizeof(m_last_call));
+
+    // Write the dataset name at the end of the file so quantize can later record it in the GGUF metadata
+    int n_dataset = strlen(dataset);
+    out.write((const char *) &n_dataset, sizeof(n_dataset));
+    out.write(dataset, n_dataset);
+
     if (m_params.verbosity > 0) {
-        fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n",__func__,m_last_call,fname);
+        fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname);
     }
 }
@@ -547,6 +557,29 @@ int main(int argc, char ** argv) {
         }
     }
 
+    gpt_params params;
+    params.n_batch = 512;
+    if (!gpt_params_parse(args.size(), args.data(), params)) {
+        return 1;
+    }
+
+    params.logits_all = true;
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    print_build_info();
+
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
+
+    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.random_prompt) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    sparams.dataset = params.prompt_file;
     g_collector.set_parameters(std::move(sparams));
 
     if (!combine_files.empty()) {
@@ -585,28 +618,6 @@ int main(int argc, char ** argv) {
         }
     }
 
-    gpt_params params;
-    params.n_batch = 512;
-    if (!gpt_params_parse(args.size(), args.data(), params)) {
-        return 1;
-    }
-
-    params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
-
-    print_build_info();
-
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
-    if (params.random_prompt) {
-        params.prompt = gpt_random_prompt(rng);
-    }
-
     llama_backend_init();
     llama_numa_init(params.numa);
diff --git a/examples/quantize/CMakeLists.txt b/examples/quantize/CMakeLists.txt
index 6f374a2bd..6b977fde8 100644
--- a/examples/quantize/CMakeLists.txt
+++ b/examples/quantize/CMakeLists.txt
@@ -1,6 +1,6 @@
 set(TARGET quantize)
 add_executable(${TARGET} quantize.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_include_directories(${TARGET} PRIVATE ../../common)
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index da1850dfd..432cc2b4f 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -8,7 +8,6 @@
 #include <unordered_map>
 #include <fstream>
 #include <cmath>
-#include <cinttypes>
 
 struct quant_option {
     std::string name;
@@ -53,6 +52,10 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
 };
 
+static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE      = "quantize.imatrix.file";
+static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET   = "quantize.imatrix.dataset";
+static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
+static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS  = "quantize.imatrix.chunks_count";
+
 static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
     std::string ftype_str;
@@ -113,7 +116,7 @@ static void usage(const char * executable) {
     exit(1);
 }
 
-static void load_imatrix(const std::string & imatrix_file, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
     std::ifstream in(imatrix_file.c_str(), std::ios::binary);
     if (!in) {
         printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
@@ -160,18 +163,33 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<s
         }
     }
 
+    // latest imatrix version contains the dataset filename at the end of the file
+    int m_last_call = 0;
+    if (in.peek() != EOF) {
+        in.read((char *) &m_last_call, sizeof(m_last_call));
+        int dataset_len;
+        in.read((char *) &dataset_len, sizeof(dataset_len));
+        std::vector<char> dataset_as_vec(dataset_len);
+        in.read(dataset_as_vec.data(), dataset_len);
+        imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end());
+        printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str());
+    }
-    printf("%s: loaded %d importance matrix entries from %s\n", __func__, int(imatrix_data.size()), imatrix_file.c_str());
+    printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call);
+    return m_last_call;
 }
 
-static void prepare_imatrix(const std::string & imatrix_file,
+static int prepare_imatrix(const std::string & imatrix_file,
+        std::string & imatrix_dataset,
         const std::vector<std::string> & included_weights,
         const std::vector<std::string> & excluded_weights,
         std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+    int m_last_call = -1;
     if (!imatrix_file.empty()) {
-        load_imatrix(imatrix_file, imatrix_data);
+        m_last_call = load_imatrix(imatrix_file, imatrix_dataset, imatrix_data);
     }
     if (imatrix_data.empty()) {
-        return;
+        return m_last_call;
     }
     if (!excluded_weights.empty()) {
         for (auto& name : excluded_weights) {
@@ -197,6 +215,7 @@ static void prepare_imatrix(const std::string & imatrix_file,
     if (!imatrix_data.empty()) {
         printf("%s: have %d importance matrix entries\n", __func__, int(imatrix_data.size()));
     }
+    return m_last_call;
 }
 
 static ggml_type parse_ggml_type(const char * arg) {
@@ -211,43 +230,6 @@ static ggml_type parse_ggml_type(const char * arg) {
     return result;
 }
 
-static bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
-    const char* sep = strchr(data, '=');
-    if (sep == nullptr || sep - data >= 128) {
-        fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
-        return false;
-    }
-    llama_model_kv_override kvo;
-    std::strncpy(kvo.key, data, sep - data);
-    kvo.key[sep - data] = 0;
-    sep++;
-    if (strncmp(sep, "int:", 4) == 0) {
-        sep += 4;
-        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-        kvo.int_value = std::atol(sep);
-    } else if (strncmp(sep, "float:", 6) == 0) {
-        sep += 6;
-        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
-        kvo.float_value = std::atof(sep);
-    } else if (strncmp(sep, "bool:", 5) == 0) {
-        sep += 5;
-        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
-        if (std::strcmp(sep, "true") == 0) {
-            kvo.bool_value = true;
-        } else if (std::strcmp(sep, "false") == 0) {
-            kvo.bool_value = false;
-        } else {
-            fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
-            return false;
-        }
-    } else {
-        fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
-        return false;
-    }
-    overrides.emplace_back(std::move(kvo));
-    return true;
-}
-
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -316,10 +298,43 @@ int main(int argc, char ** argv) {
         usage(argv[0]);
     }
 
+    std::string imatrix_dataset;
     std::unordered_map<std::string, std::vector<float>> imatrix_data;
-    prepare_imatrix(imatrix_file, included_weights, excluded_weights, imatrix_data);
+    int m_last_call = prepare_imatrix(imatrix_file, imatrix_dataset, included_weights, excluded_weights, imatrix_data);
     if (!imatrix_data.empty()) {
         params.imatrix = &imatrix_data;
+        {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+            strncpy(kvo.val_str, imatrix_file.c_str(), 127);
+            kvo.val_str[127] = '\0';
+            kv_overrides.emplace_back(std::move(kvo));
+        }
+        if (!imatrix_dataset.empty()) {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+            strncpy(kvo.val_str, imatrix_dataset.c_str(), 127);
+            kvo.val_str[127] = '\0';
+            kv_overrides.emplace_back(std::move(kvo));
+        }
+
+        {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+            kvo.val_i64 = imatrix_data.size();
+            kv_overrides.emplace_back(std::move(kvo));
+        }
+
+        if (m_last_call > 0) {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+            kvo.val_i64 = m_last_call;
+            kv_overrides.emplace_back(std::move(kvo));
+        }
     }
     if (!kv_overrides.empty()) {
         kv_overrides.emplace_back();
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 48ef8ff2a..6f8ba3fc6 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2392,7 +2392,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  -n, --n-predict       maximum tokens to predict (default: %d)\n", params.n_predict);
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
-    printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("                        types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`\n");
     printf("  -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`\n");
     printf("  --chat-template JINJA_TEMPLATE\n");
@@ -2823,43 +2823,11 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 invalid_param = true;
                 break;
             }
-            char * sep = strchr(argv[i], '=');
-            if (sep == nullptr || sep - argv[i] >= 128) {
-                fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
-                invalid_param = true;
-                break;
-            }
-
-            struct llama_model_kv_override kvo;
-            std::strncpy(kvo.key, argv[i], sep - argv[i]);
-            kvo.key[sep - argv[i]] = 0;
-            sep++;
-            if (strncmp(sep, "int:", 4) == 0) {
-                sep += 4;
-                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-                kvo.int_value = std::atol(sep);
-            } else if (strncmp(sep, "float:", 6) == 0) {
-                sep += 6;
-                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
-                kvo.float_value = std::atof(sep);
-            } else if (strncmp(sep, "bool:", 5) == 0) {
-                sep += 5;
-                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
-                if (std::strcmp(sep, "true") == 0) {
-                    kvo.bool_value = true;
-                } else if (std::strcmp(sep, "false") == 0) {
-                    kvo.bool_value = false;
-                } else {
-                    fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
-                    invalid_param = true;
-                    break;
-                }
-            } else {
+            if (!parse_kv_override(argv[i], params.kv_overrides)) {
                 fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
                 invalid_param = true;
                 break;
             }
-            params.kv_overrides.push_back(kvo);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             server_print_usage(argv[0], default_params, default_sparams);
diff --git a/llama.cpp b/llama.cpp
index e5e640016..dd8b1f264 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2883,6 +2883,7 @@ namespace GGUFMeta {
             case LLAMA_KV_OVERRIDE_TYPE_BOOL:  return "bool";
             case LLAMA_KV_OVERRIDE_TYPE_INT:   return "int";
             case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
+            case LLAMA_KV_OVERRIDE_TYPE_STR:   return "str";
         }
         return "unknown";
     }
@@ -2894,13 +2895,16 @@ namespace GGUFMeta {
                 __func__, override_type_to_str(ovrd->tag), ovrd->key);
             switch (ovrd->tag) {
                 case LLAMA_KV_OVERRIDE_TYPE_BOOL: {
-                    LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false");
+                    LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
                 } break;
                 case LLAMA_KV_OVERRIDE_TYPE_INT: {
-                    LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value);
+                    LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
                 } break;
                 case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
-                    LLAMA_LOG_INFO("%.6f\n", ovrd->float_value);
+                    LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
+                } break;
+                case LLAMA_KV_OVERRIDE_TYPE_STR: {
+                    LLAMA_LOG_INFO("%s\n", ovrd->val_str);
                 } break;
                 default:
                     // Shouldn't be possible to end up here, but just in case...
@@ -2919,7 +2923,7 @@ namespace GGUFMeta {
        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
-               target = ovrd->bool_value;
+               target = ovrd->val_bool;
                return true;
            }
            return false;
@@ -2929,7 +2933,7 @@ namespace GGUFMeta {
        template<typename OT>
        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
-               target = ovrd->int_value;
+               target = ovrd->val_i64;
                return true;
            }
            return false;
@@ -2939,7 +2943,7 @@ namespace GGUFMeta {
        template<typename T>
        static typename std::enable_if<std::is_floating_point<T>::value, bool>::type
        try_override(T & target, const struct llama_model_kv_override * ovrd) {
            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
-               target = ovrd->float_value;
+               target = ovrd->val_f64;
                return true;
            }
            return false;
@@ -2948,12 +2952,11 @@ namespace GGUFMeta {
        template<typename T>
        static typename std::enable_if<std::is_same<T, std::string>::value, bool>::type
        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-           (void)target;
-           (void)ovrd;
-           if (!ovrd) { return false; }
-           // Currently, we should never end up here so it would be a bug if we do.
-           throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n",
-               ovrd ? ovrd->key : "NULL"));
+           if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
+               target = ovrd->val_str;
+               return true;
+           }
+           return false;
        }
 
        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
@@ -14548,11 +14551,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
        for (auto & o : overrides) {
            if (o.key[0] == 0) break;
            if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
-               gguf_set_val_f32(ctx_out, o.key, o.float_value);
+               gguf_set_val_f32(ctx_out, o.key, o.val_f64);
            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
-               gguf_set_val_i32(ctx_out, o.key, o.int_value);
+               gguf_set_val_i32(ctx_out, o.key, o.val_i64);
            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
-               gguf_set_val_bool(ctx_out, o.key, o.bool_value);
+               gguf_set_val_bool(ctx_out, o.key, o.val_bool);
+           } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
+               gguf_set_val_str(ctx_out, o.key, o.val_str);
            } else {
                LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
            }
diff --git a/llama.h b/llama.h
index 0a79fa765..8b1b15ed4 100644
--- a/llama.h
+++ b/llama.h
@@ -195,15 +195,19 @@ extern "C" {
        LLAMA_KV_OVERRIDE_TYPE_INT,
        LLAMA_KV_OVERRIDE_TYPE_FLOAT,
        LLAMA_KV_OVERRIDE_TYPE_BOOL,
+       LLAMA_KV_OVERRIDE_TYPE_STR,
    };
 
    struct llama_model_kv_override {
-       char key[128];
        enum llama_model_kv_override_type tag;
+
+       char key[128];
+
        union {
-           int64_t int_value;
-           double  float_value;
-           bool    bool_value;
+           int64_t val_i64;
+           double  val_f64;
+           bool    val_bool;
+           char    val_str[128];
        };
    };
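
For reviewers, a minimal sketch (not part of the patch) of how the new `quantize.imatrix.*` metadata written during quantization can be read back from the output model. It assumes only the existing GGUF C API declared in `ggml.h`; the model filename is a placeholder.

```cpp
// read_imatrix_meta.cpp - illustrative only; "model-q4_k_m.gguf" is a placeholder path
#include "ggml.h"

#include <cstdio>
#include <initializer_list>

int main() {
    // no_alloc = true: read metadata only, do not allocate tensor data
    struct gguf_init_params iparams = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file("model-q4_k_m.gguf", iparams);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to open GGUF file\n");
        return 1;
    }

    // string-valued keys, stored through the new LLAMA_KV_OVERRIDE_TYPE_STR overrides
    for (const char * key : { "quantize.imatrix.file", "quantize.imatrix.dataset" }) {
        const int kid = gguf_find_key(ctx, key);
        if (kid >= 0) {
            printf("%s = %s\n", key, gguf_get_val_str(ctx, kid));
        }
    }

    // integer-valued keys, written via gguf_set_val_i32 during quantization
    for (const char * key : { "quantize.imatrix.entries_count", "quantize.imatrix.chunks_count" }) {
        const int kid = gguf_find_key(ctx, key);
        if (kid >= 0) {
            printf("%s = %d\n", key, gguf_get_val_i32(ctx, kid));
        }
    }

    gguf_free(ctx);
    return 0;
}
```

Since these are ordinary GGUF KV pairs, any generic GGUF metadata dumper should show them as well; the sketch above just makes the round trip explicit.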