llama : auto download HF models if URL provided

This commit is contained in:
Georgi Gerganov 2024-01-02 13:19:56 +02:00
parent 32866c5edd
commit 120a1a5515
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -3909,9 +3909,57 @@ static bool llm_load_tensors(
return true;
}
// check if the URL is a HuggingFace model, and if so, try to download it
static void hf_try_download_model(std::string & url) {
bool is_url = false;
if (url.size() > 22) {
is_url = (url.compare(0, 22, "https://huggingface.co") == 0);
}
if (!is_url) {
return;
}
// Examples:
//
// https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/resolve/main/mixtral-8x7b-instruct-v0.1.Q2_K.gguf
std::string basename;
basename = url.substr(url.find_last_of("/\\") + 1);
LLAMA_LOG_INFO("%s: detected URL, attempting to download %s\n", __func__, basename.c_str());
{
const std::string cmd = "wget -q --show-progress -c -O " + basename + " " + url;
LLAMA_LOG_INFO("%s: %s\n", __func__, cmd.c_str());
const int ret = system(cmd.c_str());
if (ret == 0) {
url = basename;
return;
}
}
{
const std::string cmd = "curl -C - -f -o " + basename + " -L " + url;
LLAMA_LOG_INFO("%s: %s\n", __func__, cmd.c_str());
const int ret = system(cmd.c_str());
if (ret == 0) {
url = basename;
return;
}
}
LLAMA_LOG_WARN("%s: failed to download\n", __func__);
}
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
static int llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
static int llama_model_load(std::string fname, llama_model & model, const llama_model_params & params) {
try {
hf_try_download_model(fname);
llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
model.hparams.vocab_only = params.vocab_only;