From 242135eca42d7437ff200570cca9c07d46575012 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 11 Jan 2025 21:35:10 +0100 Subject: [PATCH] various fixes --- common/arg.cpp | 40 +++++++++++++++++----------------------- common/common.cpp | 13 ++++--------- common/common.h | 24 ++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 32 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index dcf89489e..4a9d8ecd0 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -13,12 +13,6 @@ #include #include -#if defined(LLAMA_USE_CURL) -#include -#include -#include -#endif - #include "json-schema-to-grammar.h" using json = nlohmann::ordered_json; @@ -140,21 +134,21 @@ std::string common_arg::to_string() { * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to Q4_K_M if it exists + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) * Return pair of (with "repo" already having tag removed) */ static std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { auto parts = string_split(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? parts[1] : "latest"; // "latest" means checking Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo + std::string tag = parts.size() > 1 ? parts.back() : "latest"; std::string hf_repo = parts[0]; if (string_split(hf_repo, '/').size() != 2) { - throw std::invalid_argument("error: invalid HF repo format, expected /[:tag]\n"); + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); } // fetch model info from Hugging Face Hub API json model_info; - std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); - std::unique_ptr http_headers(nullptr, &curl_slist_free_all); + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; std::string res_str; std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); @@ -171,26 +165,27 @@ static std::pair common_get_hf_file(const std::string #endif if (!hf_token.empty()) { std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.reset(curl_slist_append(http_headers.get(), auth_header.c_str())); - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - http_headers.reset(curl_slist_append(http_headers.get(), "User-Agent: llama-cpp")); - http_headers.reset(curl_slist_append(http_headers.get(), "Accept: application/json")); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.get()); + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + CURLcode res = curl_easy_perform(curl.get()); if (res != CURLE_OK) { - throw std::runtime_error("error: cannot make GET request to Hugging Face Hub API"); + throw std::runtime_error("error: cannot make GET request to HF API"); } long res_code; curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); if (res_code == 200) { model_info = json::parse(res_str); - } if (res_code == 401) { + } else if (res_code == 401) { throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); } else { - throw std::runtime_error(string_format("error: cannot get model info from Hugging Face Hub API, response code: %ld", res_code)); + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); } // check response @@ -202,7 +197,6 @@ static std::pair common_get_hf_file(const std::string throw std::runtime_error("error: ggufFile does not have rfilename"); } - // TODO handle error return std::make_pair(hf_repo, gguf_file.at("rfilename")); } #else @@ -1676,7 +1670,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_MODEL_URL")); add_opt(common_arg( - {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", + {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" "example: unsloth/phi-4-GGUF:q4_k_m\n" "(default: unused)", @@ -1686,13 +1680,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_env("LLAMA_ARG_HF_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", - "Hugging Face model file, unused if quant is already specified in --hf-repo (default: unused)", + "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", [](common_params & params, const std::string & value) { params.hf_file = value; } ).set_env("LLAMA_ARG_HF_FILE")); add_opt(common_arg( - {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", + {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", "Hugging Face model repository for the vocoder model (default: unused)", [](common_params & params, const std::string & value) { params.vocoder.hf_repo = value; diff --git a/common/common.cpp b/common/common.cpp index dca7ddf69..b25ef2c5e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -52,11 +52,6 @@ #include #include #endif -#if defined(LLAMA_USE_CURL) -#include -#include -#include -#endif #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -1126,8 +1121,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl - std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); - std::unique_ptr http_headers(nullptr, &curl_slist_free_all); + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; if (!curl) { LOG_ERR("%s: error initializing libcurl\n", __func__); return false; @@ -1142,8 +1137,8 @@ static bool common_download_file(const std::string & url, const std::string & pa // Check if hf-token or bearer-token was specified if (!hf_token.empty()) { std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.reset(curl_slist_append(http_headers.get(), auth_header.c_str())); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.get()); + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); } #if defined(_WIN32) diff --git a/common/common.h b/common/common.h index 42d75ef4b..a2c97cd51 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,12 @@ #include #include +#if defined(LLAMA_USE_CURL) +#include +#include +#include +#endif + #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' #else @@ -651,4 +657,22 @@ const char * const LLM_KV_SPLIT_NO = "split.no"; const char * const LLM_KV_SPLIT_COUNT = "split.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; +#if defined(LLAMA_USE_CURL) +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; +#endif + }