Compare commits

...

19 Commits

Author SHA1 Message Date
compilade
403df8e353
Merge 2d79a7077c into 8db003a19d 2024-09-11 15:29:55 +03:00
Pavel Zloi
8db003a19d
py : support converting local models (#7547)
* Support of converting local models added to convert-hf-to-gguf-update.py

* Description fixed

* shutil added to imports
2024-09-11 15:29:51 +03:00
Xuan Son Nguyen
0996c5597f
llava : correct args for minicpmv-cli (#9429) 2024-09-11 12:59:13 +02:00
Xuan Son Nguyen
5bb2c5dbd2
files : remove accidentally added lora_test submodule (#9430) 2024-09-11 13:02:09 +03:00
Farbod Bijary
67155ab7f5
feat: Implements retrying logic for downloading models using --model-url flag (#9255)
* feat: Implements retrying logic for downloading models using --model-url flag

* Update common/common.cpp

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>

* Update common/common.cpp

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>

* apply comments

* implements a retry function to avoid duplication

* fix editorconfig

* change function name

---------

Co-authored-by: farbod <farbod.bjary82@gmail.com>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2024-09-11 11:22:37 +02:00
Johannes Gäßler
5af118efda
CUDA: fix --split-mode row race condition (#9413) 2024-09-11 10:22:40 +02:00
Francis Couture-Harpin
2d79a7077c quantize : use unused imatrix chunk_size with LLAMA_TRACE 2024-09-10 12:09:17 -04:00
Francis Couture-Harpin
8c13e16bb0 imatrix : allow loading mis-ordered tensors
Sums and counts tensors no longer need to be consecutive.

* imatrix : more sanity checks when loading multiple imatrix files

* imatrix : use ggml_format_name instead of std::string concatenation

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2024-09-10 11:51:23 -04:00
Francis Couture-Harpin
2217247051 imatrix : remove unused n_entries 2024-09-09 22:35:47 -04:00
Francis Couture-Harpin
efa9186dc8 imatrix : avoid using designated initializers in C++ 2024-09-09 22:33:10 -04:00
Francis Couture-Harpin
894ed8d7b6 py : include imatrix converter requirements in toplevel requirements 2024-09-09 22:20:18 -04:00
Francis Couture-Harpin
9e6b0e9419 perplexity : revert changes 2024-09-09 22:00:37 -04:00
Francis Couture-Harpin
503630e88a py : add requirements for legacy imatrix convert script 2024-09-09 21:56:04 -04:00
Francis Couture-Harpin
d19101c9a0 imatrix : use FMA and sort tensor names 2024-09-08 11:03:59 -04:00
Francis Couture-Harpin
3ad0603c65 Merge branch 'master' into compilade/imatrix-batched-chunks 2024-09-08 10:05:08 -04:00
Francis Couture-Harpin
c8ab6a3ba3 imatrix : fix conversion problems 2024-09-08 10:04:01 -04:00
Francis Couture-Harpin
3de9300c37 imatrix : use GGUF to store imatrix data 2024-09-06 17:17:25 -04:00
Francis Couture-Harpin
347247a24e imatrix : fix segfault when using a single chunk per batch 2024-08-20 15:35:56 -04:00
Francis Couture-Harpin
bce54642c8 imatrix : allow processing multiple chunks per batch
* perplexity : simplify filling the batch
2024-08-20 15:17:24 -04:00
11 changed files with 546 additions and 195 deletions

View File

@ -941,11 +941,37 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
#ifdef LLAMA_USE_CURL
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2
static bool starts_with(const std::string & str, const std::string & prefix) {
// While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
}
static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts;
while (remaining_attempts > 0) {
fprintf(stderr, "%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
fprintf(stderr, "%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
remaining_attempts--;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
}
fprintf(stderr, "%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
return false;
}
static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl
@ -1049,9 +1075,8 @@ static bool llama_download_file(const std::string & url, const std::string & pat
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
@ -1126,11 +1151,10 @@ static bool llama_download_file(const std::string & url, const std::string & pat
};
// start the download
fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
auto res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
fprintf(stderr, "%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
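For reference, the retry delay added above grows exponentially with the number of failed attempts: the sleep is retry_delay_seconds raised to the attempt index, converted to milliseconds. A small Python sketch (not part of the patch) of the schedule under the default CURL_MAX_RETRY = 3 and CURL_RETRY_DELAY_SECONDS = 2:

# Sketch of the backoff schedule implied by curl_perform_with_retry() above.
# With the defaults, failed attempts wait 1000 ms, 2000 ms and 4000 ms.
max_attempts = 3
retry_delay_seconds = 2
for attempt in range(max_attempts):
    delay_ms = retry_delay_seconds ** attempt * 1000
    print(f"attempt {attempt + 1}/{max_attempts} failed: waiting {delay_ms} ms")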

View File

@ -31,6 +31,7 @@ import re
import requests
import sys
import json
import shutil
from hashlib import sha256
from enum import IntEnum, auto
@ -125,12 +126,27 @@ def download_model(model):
if tokt == TOKENIZER_TYPE.UGM:
files.append("spiece.model")
for file in files:
save_path = f"models/tokenizers/{name}/{file}"
if os.path.isfile(save_path):
logger.info(f"{name}: File {save_path} already exists - skipping")
continue
download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path)
if os.path.isdir(repo):
# If repo is a path on the file system, copy the directory
for file in files:
src_path = os.path.join(repo, file)
dst_path = f"models/tokenizers/{name}/{file}"
if os.path.isfile(dst_path):
logger.info(f"{name}: File {dst_path} already exists - skipping")
continue
if os.path.isfile(src_path):
shutil.copy2(src_path, dst_path)
logger.info(f"{name}: Copied {src_path} to {dst_path}")
else:
logger.warning(f"{name}: Source file {src_path} does not exist")
else:
# If repo is a URL, download the files
for file in files:
save_path = f"models/tokenizers/{name}/{file}"
if os.path.isfile(save_path):
logger.info(f"{name}: File {save_path} already exists - skipping")
continue
download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path)
for model in models:
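With this change, the repo field of an entry in the convert_hf_to_gguf_update.py models list may point at a local directory instead of a Hugging Face URL; the tokenizer files are then copied with shutil.copy2 rather than downloaded. A hypothetical example (the exact entry shape here is an assumption, not part of the diff):

# Hypothetical entries: a remote repo is downloaded as before, while a local
# path is detected via os.path.isdir() and its files are copied instead.
remote_entry = {"name": "llama-bpe", "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B"}
local_entry  = {"name": "my-local-model", "repo": "/path/to/local/model"}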

View File

@ -0,0 +1,122 @@
#!/usr/bin/env python3
from __future__ import annotations
import os
import sys
import logging
import argparse
from typing import Any
from pathlib import Path
from dataclasses import dataclass
import numpy as np
import numpy.typing as npt
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
logger = logging.getLogger("imatrix-to-gguf")
class IMatrixWriter(gguf.GGUFWriter):
def add_architecture(self) -> None:
# no arch is stored in imatrix files
pass
@dataclass
class IMatrixEntry:
values: np.ndarray[Any, np.dtype[np.float32]]
counts: np.ndarray[Any, np.dtype[np.float32]]
class IMatrixReader:
chunk_size: int = 512 # guess
offset: int = 0
data: np.ndarray[Any, np.dtype[np.uint8]]
n_entries: int
entries: dict[str, IMatrixEntry]
chunk_count: int
dataset: str
def _get(self, dtype: npt.DTypeLike, count: int = 1) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype=dtype).itemsize)
offset = self.offset
self.offset = offset + itemsize * count
return self.data[offset:self.offset].view(dtype=dtype)[:count]
def __init__(self, imatrix: Path):
self.offset = 0
self.entries = {}
self.data = np.memmap(imatrix)
n_entries = self._get(np.int32).item()
assert n_entries >= 0
for _ in range(n_entries):
len = self._get(np.int32).item()
name = self._get(np.uint8, len).tobytes().decode("utf-8")
ncall = self._get(np.int32).item()
nval = self._get(np.int32).item()
data = self._get(np.float32, nval)
assert name not in self.entries, f"duplicated name: {name!r}"
self.entries[name] = IMatrixEntry(data * np.float32(self.chunk_size), np.array([ncall * self.chunk_size], dtype=np.float32))
self.chunk_count = self._get(np.int32).item()
dataset_len = self._get(np.int32).item()
self.dataset = self._get(np.uint8, dataset_len).tobytes().decode("utf-8")
def to_writer(self, outfile: Path) -> IMatrixWriter:
writer = IMatrixWriter(path=outfile, arch="")
writer.add_type(gguf.GGUFType.IMATRIX)
writer.add_key_value(gguf.Keys.IMatrix.CHUNK_COUNT, self.chunk_count, gguf.GGUFValueType.UINT32)
writer.add_key_value(gguf.Keys.IMatrix.CHUNK_SIZE, self.chunk_size, gguf.GGUFValueType.UINT32)
writer.add_key_value(gguf.Keys.IMatrix.DATASET, self.dataset, gguf.GGUFValueType.STRING)
for name, entry in self.entries.items():
writer.add_tensor(name + ".sums", entry.values)
writer.add_tensor(name + ".counts", entry.counts)
return writer
def parse_args():
parser = argparse.ArgumentParser(
description="Convert an old imatrix.dat file to a GGUF compatible file")
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input.",
)
parser.add_argument(
"--verbose", action="store_true",
help="increase output verbosity",
)
parser.add_argument(
"imatrix", type=Path,
help="path to an imatrix file",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
if args.outfile is None:
input_file: Path = args.imatrix
if input_file.suffix != ".gguf":
args.outfile = input_file.with_suffix(".gguf")
if args.outfile.exists():
logger.error(f"default file exists, specify with --outfile to overwrite: {args.outfile}")
exit(1)
writer = IMatrixReader(args.imatrix).to_writer(args.outfile)
writer.write_header_to_file(args.outfile)
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
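For reference, the legacy imatrix.dat layout that IMatrixReader parses above is a flat binary stream. A minimal Python sketch (not part of the diff; field order and little-endian layout are inferred from the reader) that writes a file the converter can ingest:

import struct

# Legacy layout implied by IMatrixReader: int32 n_entries, then per entry
# int32 name_len, name bytes, int32 ncall, int32 nval, float32[nval],
# followed by int32 chunk_count, int32 dataset_len and the dataset string.
def write_legacy_imatrix(path, entries, chunk_count, dataset):
    # entries: dict mapping tensor name -> (ncall, list of float values)
    with open(path, "wb") as f:
        f.write(struct.pack("<i", len(entries)))
        for name, (ncall, values) in entries.items():
            name_bytes = name.encode("utf-8")
            f.write(struct.pack("<i", len(name_bytes)))
            f.write(name_bytes)
            f.write(struct.pack("<ii", ncall, len(values)))
            f.write(struct.pack(f"<{len(values)}f", *values))
        f.write(struct.pack("<i", chunk_count))
        dataset_bytes = dataset.encode("utf-8")
        f.write(struct.pack("<i", len(dataset_bytes)))
        f.write(dataset_bytes)

# example: write_legacy_imatrix("imatrix.dat", {"blk.0.ffn_up.weight": (10, [0.5] * 4096)}, 10, "wiki.train.raw")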

View File

@ -6,12 +6,11 @@
#include <cstdio>
#include <cstring>
#include <ctime>
#include <sstream>
#include <thread>
#include <mutex>
#include <vector>
#include <fstream>
#include <unordered_map>
#include <map>
#include <algorithm>
#if defined(_MSC_VER)
@ -21,16 +20,27 @@
static void print_usage(int, char ** argv) {
LOG_TEE("\nexample usage:\n");
LOG_TEE("\n %s \\\n"
" -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \\\n"
" -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] [--verbosity 1] \\\n"
" [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
" [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]);
" [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...]\n" , argv[0]);
LOG_TEE("\n");
}
static bool str_remove_suffix(std::string & str, const std::string & suffix) {
bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0;
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset";
static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count";
static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size";
struct Stats {
std::vector<float> values;
std::vector<int> counts;
int ncall = 0;
std::vector<float> values;
std::vector<int64_t> counts;
};
class IMatrixCollector {
@ -38,13 +48,13 @@ public:
IMatrixCollector() = default;
void set_params(gpt_params params) { m_params = std::move(params); }
bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
void save_imatrix(int ncall = -1) const;
void save_imatrix(int32_t n_chunk = -1) const;
bool load_imatrix(const char * file_name);
private:
std::unordered_map<std::string, Stats> m_stats;
gpt_params m_params;
std::mutex m_mutex;
int m_last_call = 0;
int32_t m_last_chunk = 0;
std::vector<float> m_src1_data;
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
};
@ -118,18 +128,24 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
auto & e = m_stats[wname];
++e.ncall;
if (e.counts.size() == 1 && n_as > 1) {
// broadcast, when loading an old imatrix
e.counts.resize(n_as, e.counts[0]);
}
if (e.values.empty()) {
e.values.resize(src1->ne[0]*n_as, 0);
e.counts.resize(src1->ne[0]*n_as, 0);
e.counts.resize(n_as, 0);
}
else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
exit(1); //GGML_ABORT("fatal error");
}
else if (e.counts.size() != (size_t)n_as) {
fprintf(stderr, "Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), (int)n_as);
exit(1); //GGML_ABORT("fatal error");
}
if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
}
// loop over all possible experts, regardless if they are used or not in the batch
for (int ex = 0; ex < n_as; ++ex) {
@ -147,23 +163,26 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
const int64_t i12 = row;
const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
e.counts[ex]++;
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[e_start + j] += x[j]*x[j];
e.counts[e_start + j]++;
if (!std::isfinite(e.values[e_start + j])) {
fprintf(stderr, "%f detected in %s\n", e.values[e_start + j], wname.c_str());
e.values[e_start + j] = std::fma(x[j], x[j], e.values[e_start + j]);
if (!std::isfinite((float)e.values[e_start + j])) {
fprintf(stderr, "%f detected in %s\n", (float)e.values[e_start + j], wname.c_str());
exit(1);
}
}
}
}
if (e.ncall > m_last_call) {
m_last_call = e.ncall;
if (m_last_call % m_params.n_out_freq == 0) {
const int32_t n_chunk = e.counts[ex] / (m_params.n_ctx / m_params.n_parallel);
if (n_chunk > m_last_chunk) {
const int32_t chunk_step = n_chunk - m_last_chunk;
m_last_chunk = n_chunk;
if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) {
save_imatrix();
}
if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
save_imatrix(m_last_call);
if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) {
save_imatrix(m_last_chunk);
}
}
}
@ -171,34 +190,40 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
auto & e = m_stats[wname];
if (e.values.empty()) {
e.values.resize(src1->ne[0], 0);
e.counts.resize(src1->ne[0], 0);
e.counts.resize(1, 0);
}
else if (e.values.size() != (size_t)src1->ne[0]) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
exit(1); //GGML_ABORT("fatal error");
}
++e.ncall;
if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
else if (e.counts.size() != 1) {
fprintf(stderr, "Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), 1);
exit(1); //GGML_ABORT("fatal error");
}
if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
}
// TODO: higher dimensions
for (int row = 0; row < (int)src1->ne[1]; ++row) {
const float * x = data + row * src1->ne[0];
e.counts[0]++;
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[j] += x[j]*x[j];
e.counts[j]++;
if (!std::isfinite(e.values[j])) {
fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str());
e.values[j] = std::fma(x[j], x[j], e.values[j]);
if (!std::isfinite((float)e.values[j])) {
fprintf(stderr, "%f detected in %s\n", (float)e.values[j], wname.c_str());
exit(1);
}
}
}
if (e.ncall > m_last_call) {
m_last_call = e.ncall;
if (m_last_call % m_params.n_out_freq == 0) {
const int32_t n_chunk = e.counts[0] / (m_params.n_ctx / m_params.n_parallel);
if (n_chunk > m_last_chunk) {
const int32_t chunk_step = n_chunk - m_last_chunk;
m_last_chunk = n_chunk;
if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) {
save_imatrix();
}
if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
save_imatrix(m_last_call);
if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) {
save_imatrix(m_last_chunk);
}
}
}
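In both branches above, a chunk is n_ctx / n_parallel tokens of a single sequence, and one batch can now complete several chunks at once, so m_last_chunk may advance by chunk_step > 1. The test (m_last_chunk % n_out_freq) / chunk_step == 0 therefore fires whenever a multiple of n_out_freq chunks was reached or crossed during the last step. A small Python sketch (hypothetical numbers, not part of the diff) of the trigger:

n_out_freq = 10

def should_save(last_chunk, chunk_step):
    # mirrors (m_last_chunk % n_out_freq) / chunk_step == 0 with integer division
    return (last_chunk % n_out_freq) // chunk_step == 0

print(should_save(10, 1))  # True:  exactly at a multiple of n_out_freq
print(should_save(12, 4))  # True:  jumped from 8 to 12, crossing 10
print(should_save(9, 4))   # False: jumped from 5 to 9, no multiple crossed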
@ -206,22 +231,22 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
return true;
}
void IMatrixCollector::save_imatrix(int ncall) const {
void IMatrixCollector::save_imatrix(int32_t n_chunk) const {
auto fname = m_params.out_file;
if (fname.empty()) {
fname = "imatrix.dat";
fname = "imatrix.gguf";
}
if (ncall > 0) {
if (n_chunk > 0) {
fname += ".at_";
fname += std::to_string(ncall);
fname += std::to_string(n_chunk);
}
// avoid writing imatrix entries that do not have full data
// this can happen with MoE models where some of the experts end up not being exercised by the provided training data
int n_entries = 0;
std::vector<std::string> to_store;
size_t data_size = 0;
bool is_first = true; // for printing
for (const auto & kv : m_stats) {
@ -253,102 +278,158 @@ void IMatrixCollector::save_imatrix(int ncall) const {
continue;
}
n_entries++;
to_store.push_back(kv.first);
data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN);
data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN);
}
if (to_store.size() < m_stats.size()) {
fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
}
std::ofstream out(fname, std::ios::binary);
out.write((const char *) &n_entries, sizeof(n_entries));
// deterministic tensor name order
std::sort(to_store.begin(), to_store.end());
struct ggml_init_params params = {
/* .mem_size = */ data_size,
/* .mem_buffer = */ NULL,
/* .no_alloc = */ false,
};
struct ggml_context * ctx = ggml_init(params);
struct gguf_context * ctx_gguf = gguf_init_empty();
gguf_set_val_str(ctx_gguf, "general.type", "imatrix");
// Write the input filename to later on specify it in quantize
gguf_set_val_str(ctx_gguf, LLM_KV_IMATRIX_DATASET, m_params.prompt_file.c_str());
// Write the number of chunks the matrix was computed with
gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT, m_last_chunk);
gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE, m_params.n_ctx / m_params.n_parallel);
for (const auto & name : to_store) {
const auto & stat = m_stats.at(name);
int len = name.size();
out.write((const char *) &len, sizeof(len));
out.write(name.c_str(), len);
out.write((const char *) &stat.ncall, sizeof(stat.ncall));
int nval = stat.values.size();
out.write((const char *) &nval, sizeof(nval));
const int32_t nval = (int32_t) stat.values.size();
const int32_t nmat = (int32_t) stat.counts.size();
if (nval > 0) {
std::vector<float> tmp(nval);
for (int i = 0; i < nval; i++) {
tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
struct ggml_tensor * sums = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nval / nmat, nmat);
struct ggml_tensor * counts = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, nmat);
ggml_format_name(sums, "%s.sums", name.c_str());
ggml_format_name(counts, "%s.counts", name.c_str());
for (int32_t j = 0; j < nval; ++j) {
((float *) sums->data)[j] = (float) stat.values[j];
}
out.write((const char*)tmp.data(), nval*sizeof(float));
for (int32_t j = 0; j < nmat; ++j) {
((float *) counts->data)[j] = (float) stat.counts[j];
}
gguf_add_tensor(ctx_gguf, sums);
gguf_add_tensor(ctx_gguf, counts);
}
}
// Write the number of calls the matrix was computed with
out.write((const char *) &m_last_call, sizeof(m_last_call));
// Write the input filename at the end of the file to later on specify it in quantize
{
int len = m_params.prompt_file.size();
out.write((const char *) &len, sizeof(len));
out.write(m_params.prompt_file.c_str(), len);
}
gguf_write_to_file(ctx_gguf, fname.c_str(), false);
if (m_params.verbosity > 0) {
fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str());
fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str());
}
gguf_free(ctx_gguf);
ggml_free(ctx);
}
bool IMatrixCollector::load_imatrix(const char * fname) {
std::ifstream in(fname, std::ios::binary);
if (!in) {
printf("%s: failed to open %s\n",__func__, fname);
bool IMatrixCollector::load_imatrix(const char * file_name) {
struct ggml_context * ctx = nullptr;
struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ false, // the data is needed
/* .ctx = */ &ctx,
};
struct gguf_context * ctx_gguf = gguf_init_from_file(file_name, meta_gguf_params);
if (!ctx_gguf) {
return false;
}
int n_entries;
in.read((char*)&n_entries, sizeof(n_entries));
if (in.fail() || n_entries < 1) {
printf("%s: no data in file %s\n", __func__, fname);
const int32_t n_entries = gguf_get_n_tensors(ctx_gguf);
if (n_entries < 1) {
fprintf(stderr, "%s: no data in file %s\n", __func__, file_name);
gguf_free(ctx_gguf);
ggml_free(ctx);
return false;
}
for (int i = 0; i < n_entries; ++i) {
int len; in.read((char *)&len, sizeof(len));
std::vector<char> name_as_vec(len+1);
in.read((char *)name_as_vec.data(), len);
if (in.fail()) {
printf("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
const std::string sums_suffix{".sums"};
const std::string counts_suffix{".counts"};
// Could re-use m_stats instead, but this allows
// checking for completeness of *each* loaded imatrix file
// and also makes it easier to re-use a similar implementation in quantize.cpp
// Using an ordered map to get a deterministic iteration order.
std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
std::string name = cur->name;
if (name.empty()) { continue; }
if (str_remove_suffix(name, sums_suffix)) {
// sums
sums_counts_for[name].first = cur;
} else if (str_remove_suffix(name, counts_suffix)) {
// counts
sums_counts_for[name].second = cur;
} else {
fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
return false;
}
name_as_vec[len] = 0;
std::string name{name_as_vec.data()};
auto & e = m_stats[std::move(name)];
int ncall;
in.read((char*)&ncall, sizeof(ncall));
int nval;
in.read((char *)&nval, sizeof(nval));
if (in.fail() || nval < 1) {
printf("%s: failed reading number of values for entry %d\n",__func__,i);
m_stats = {};
}
for (const auto & sc : sums_counts_for) {
const std::string & name = sc.first;
const struct ggml_tensor * sums = sc.second.first;
const struct ggml_tensor * counts = sc.second.second;
if (!sums || !counts) {
fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
return false;
}
auto & e = m_stats[name];
int64_t nval = ggml_nelements(sums);
if (e.values.empty()) {
e.values.resize(nval, 0);
e.counts.resize(nval, 0);
}
std::vector<float> tmp(nval);
in.read((char*)tmp.data(), nval*sizeof(float));
if (in.fail()) {
printf("%s: failed reading data for entry %d\n",__func__,i);
m_stats = {};
} else if ((size_t) nval != e.values.size()) {
fprintf(stderr, "%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size());
gguf_free(ctx_gguf);
ggml_free(ctx);
return false;
}
// Recreate the state as expected by save_imatrix(), and correct for weighted sum.
for (int i = 0; i < nval; i++) {
e.values[i] += tmp[i];
e.counts[i] += ncall;
int64_t ncounts = ggml_nelements(counts);
if (e.counts.empty()) {
e.counts.resize(ncounts, 0);
} else if (e.counts.size() == 1 && ncounts > 1) {
// broadcast, when loading an old imatrix
e.counts.resize(ncounts, e.counts[0]);
} else if ((size_t) ncounts != e.counts.size()) {
fprintf(stderr, "%s: mismatched counts size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) ncounts, e.counts.size());
gguf_free(ctx_gguf);
ggml_free(ctx);
return false;
}
e.ncall += ncall;
// Recreate the state as expected by save_imatrix()
for (int64_t j = 0; j < nval; j++) {
e.values[j] += ((const float *) sums->data)[j];
}
for (int64_t j = 0; j < ncounts; j++) {
e.counts[j] += std::lround(((const float *) counts->data)[j]);
}
}
gguf_free(ctx_gguf);
ggml_free(ctx);
return true;
}
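Because the imatrix data is now a pair of GGUF tensors per entry (<name>.sums and <name>.counts), load_imatrix() no longer requires the two halves to be consecutive in the file: it collects them into an ordered map keyed by the base name and only then checks that both are present. A minimal Python sketch (assuming only the naming convention written by save_imatrix() above) of that pairing:

def pair_sums_and_counts(tensor_names):
    pairs = {}  # base name -> {"sums": ..., "counts": ...}
    for tname in sorted(tensor_names):  # sorted for a deterministic order, like std::map
        if tname.endswith(".sums"):
            pairs.setdefault(tname[:-len(".sums")], {})["sums"] = tname
        elif tname.endswith(".counts"):
            pairs.setdefault(tname[:-len(".counts")], {})["counts"] = tname
        else:
            raise ValueError(f"invalid imatrix tensor name: {tname}")
    for base, p in pairs.items():
        if "sums" not in p or "counts" not in p:
            raise ValueError(f"mismatched sums and counts for {base}")
    return pairs

# mis-ordered tensors are accepted:
pair_sums_and_counts(["blk.0.ffn_up.weight.counts", "blk.1.ffn_up.weight.sums",
                      "blk.0.ffn_up.weight.sums", "blk.1.ffn_up.weight.counts"])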
@ -431,10 +512,9 @@ static void process_logits(
}
}
static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@ -478,22 +558,28 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
double nll = 0.0;
double nll2 = 0.0;
fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
const int n_seq = std::max(1, n_batch / n_ctx);
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
std::vector<float> logits;
if (params.compute_ppl && num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}
for (int i = 0; i < n_chunk; ++i) {
fprintf(stderr, "%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
for (int i = 0; i < n_chunk; i += n_seq) {
const int start = i * n_ctx;
const int end = start + n_ctx;
std::vector<float> logits;
const int n_seq_batch = std::min(n_seq, n_chunk - i);
const auto t_start = std::chrono::high_resolution_clock::now();
@ -504,35 +590,50 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
// clear the batch
llama_batch_clear(batch);
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
// save original token and restore it after eval
const auto token_org = tokens[seq_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
}
for (int k = 0; k < batch_size; ++k) {
// NOTE: specifying all logits to get activations for the output.weight tensor
// and also for the perplexity calculation.
// TODO: only get outputs when (params.process_output || params.compute_ppl)
// (not possible when this skips FFN computation of the last layer)
llama_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true);
}
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}
// TODO: use batch.logits to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
if (params.compute_ppl && num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
}
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
llama_synchronize(ctx);
const auto t_end = std::chrono::high_resolution_clock::now();
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
int total_seconds = (int)(t_total*n_chunk/n_seq);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
@ -542,12 +643,21 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
if (params.compute_ppl) {
const int first = n_ctx/2;
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1;
for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
process_logits(n_vocab, all_logits + first*n_vocab,
tokens_data, n_ctx - 1 - first,
workers, nll, nll2,
logit_history.data() + start + seq*n_ctx + first,
prob_history.data() + start + seq*n_ctx + first);
count += n_ctx - first - 1;
printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
}
fflush(stdout);
logits.clear();
@ -582,7 +692,22 @@ int main(int argc, char ** argv) {
return 1;
}
params.n_batch = std::min(params.n_batch, params.n_ctx);
const int32_t n_ctx = params.n_ctx;
if (n_ctx <= 0) {
fprintf(stderr, "%s: imatrix tool requires '--ctx-size' > 0\n", __func__);
return 1;
}
{
const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
const int32_t n_kv = n_seq * n_ctx;
params.n_parallel = n_seq;
params.n_ctx = n_kv;
params.n_batch = std::min(params.n_batch, n_kv);
}
g_collector.set_params(params);
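After this block, the user-supplied --ctx-size is interpreted as the size of one chunk and --batch-size decides how many chunks (sequences) are evaluated per decode call, with the KV cache sized to hold all of them. A small Python sketch (hypothetical flag values, not part of the diff) of the derived parameters:

n_ctx   = 512    # --ctx-size: tokens per chunk (per sequence)
n_batch = 2048   # --batch-size

n_seq = max(1, n_batch // n_ctx)  # chunks per llama_decode() call -> 4
n_kv  = n_seq * n_ctx             # KV cache size (params.n_ctx)   -> 2048
n_batch = min(n_batch, n_kv)      # batch never exceeds the cache  -> 2048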
@ -630,7 +755,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
}
if (!compute_imatrix(ctx, params)) {
if (!compute_imatrix(ctx, params, n_ctx)) {
return 1;
}

View File

@ -18,8 +18,8 @@ struct llava_context {
};
static void show_additional_info(int /*argc*/, char ** argv) {
LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
LOG_TEE("\nexample usage:\n\n%s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
LOG_TEE("\nnote: a lower temperature value like 0.1 is recommended for better quality.\n");
}
static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
@ -255,7 +255,7 @@ int main(int argc, char ** argv) {
gpt_params params;
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, show_additional_info)) {
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
return 1;
}

View File

@ -6,8 +6,7 @@
#include <vector>
#include <string>
#include <unordered_map>
#include <fstream>
#include <cmath>
#include <map>
struct quant_option {
std::string name;
@ -63,6 +62,11 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count";
// TODO: share with imatrix.cpp
static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset";
static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count";
static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size";
static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
std::string ftype_str;
@ -122,67 +126,114 @@ static void usage(const char * executable) {
exit(1);
}
// TODO: share with implementation in imatrix.cpp
static bool str_remove_suffix(std::string & str, const std::string & suffix) {
bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0;
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
std::ifstream in(imatrix_file.c_str(), std::ios::binary);
if (!in) {
printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
struct ggml_context * ctx = nullptr;
struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ false, // the data is needed
/* .ctx = */ &ctx,
};
struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params);
if (!ctx_gguf) {
fprintf(stderr, "%s: if this is an older imatrix file, make sure to convert it to the GGUF-based imatrix format\n", __func__);
exit(1);
}
int n_entries;
in.read((char *)&n_entries, sizeof(n_entries));
if (in.fail() || n_entries < 1) {
printf("%s: no data in file %s\n", __func__, imatrix_file.c_str());
const int32_t n_entries = gguf_get_n_tensors(ctx_gguf);
if (n_entries < 1) {
fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
for (int i = 0; i < n_entries; ++i) {
int len; in.read((char *)&len, sizeof(len));
std::vector<char> name_as_vec(len+1);
in.read((char *)name_as_vec.data(), len);
if (in.fail()) {
printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str());
const int dataset_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASET);
const int chunk_count_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT);
const int chunk_size_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE);
if (dataset_idx < 0 || chunk_count_idx < 0 || chunk_size_idx < 0) {
fprintf(stderr, "%s: missing imatrix metadata in file %s\n", __func__, imatrix_file.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
const uint32_t chunk_size = gguf_get_val_u32(ctx_gguf, chunk_size_idx);
const std::string sums_suffix{".sums"};
const std::string counts_suffix{".counts"};
// Using an ordered map to get a deterministic iteration order.
std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
std::string name = cur->name;
if (name.empty()) { continue; }
if (str_remove_suffix(name, sums_suffix)) {
// sums
sums_counts_for[name].first = cur;
} else if (str_remove_suffix(name, counts_suffix)) {
// counts
sums_counts_for[name].second = cur;
} else {
fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
name_as_vec[len] = 0;
std::string name{name_as_vec.data()};
}
for (const auto & sc : sums_counts_for) {
const std::string & name = sc.first;
const struct ggml_tensor * sums = sc.second.first;
const struct ggml_tensor * counts = sc.second.second;
if (!sums || !counts) {
fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
const int64_t ne0 = sums->ne[0];
const int64_t ne1 = sums->ne[1];
auto & e = imatrix_data[name];
int ncall;
in.read((char *)&ncall, sizeof(ncall));
int nval;
in.read((char *)&nval, sizeof(nval));
if (in.fail() || nval < 1) {
printf("%s: failed reading number of values for entry %d\n", __func__, i);
imatrix_data = {};
exit(1);
e.resize(ggml_nelements(sums));
float max_count = 0.0f;
for (int64_t j = 0; j < ne1; ++j) {
const float count = ((const float *) counts->data)[j];
for (int64_t i = 0; i < ne0; ++i) {
e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count;
}
if (count > max_count) {
max_count = count;
}
}
e.resize(nval);
in.read((char *)e.data(), nval*sizeof(float));
if (in.fail()) {
printf("%s: failed reading data for entry %d\n", __func__, i);
imatrix_data = {};
exit(1);
}
if (ncall > 0) {
for (auto& v : e) v /= ncall;
}
if (getenv("LLAMA_TRACE")) {
printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str());
printf("%s: loaded data (size = %6d, n_tokens = %6d, n_chunks = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), int(max_count / chunk_size), name.c_str());
}
}
// latest imatrix version contains the dataset filename at the end of the file
int m_last_call = 0;
if (in.peek() != EOF) {
in.read((char *)&m_last_call, sizeof(m_last_call));
int dataset_len;
in.read((char *)&dataset_len, sizeof(dataset_len));
std::vector<char> dataset_as_vec(dataset_len);
in.read(dataset_as_vec.data(), dataset_len);
imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end());
printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str());
}
printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call);
return m_last_call;
int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx);
imatrix_dataset = gguf_get_val_str(ctx_gguf, dataset_idx);
printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str());
printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk);
gguf_free(ctx_gguf);
ggml_free(ctx);
return m_last_chunk;
}
static int prepare_imatrix(const std::string & imatrix_file,
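Here quantize reconstructs the per-value importance by dividing each row of the .sums tensor by the matching entry of the .counts tensor (one count per expert), and keeps the largest count to report token and chunk totals under LLAMA_TRACE. A small numpy sketch (hypothetical shapes, not part of the diff):

import numpy as np

ne0, ne1 = 4, 2  # ne0 values per expert, ne1 experts
sums   = np.arange(ne0 * ne1, dtype=np.float32).reshape(ne1, ne0)  # <name>.sums
counts = np.array([[2.0], [4.0]], dtype=np.float32)                # <name>.counts, one per expert

e = (sums / counts).reshape(-1)  # e[j*ne0 + i] = sums[j][i] / counts[j]
max_count = float(counts.max())  # together with imatrix.chunk_size gives n_chunks for LLAMA_TRACE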

View File

@ -26,7 +26,11 @@ void ggml_cuda_op_mul_mat_q(
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = compute_capability >= CC_VOLTA && compute_capability < CC_OFFSET_AMD && src1_ncols == ne11;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
switch (src0->type) {
case GGML_TYPE_Q4_0:

View File

@ -2742,6 +2742,7 @@ struct mmq_args {
int64_t ne00; int64_t ne01; int64_t stride01;
int64_t ne10; int64_t ne11; int64_t stride11;
int64_t ne0;
bool use_stream_k;
};
template<ggml_type type>
@ -2777,8 +2778,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
const dim3 block_nums_xy_tiling(nty, ntx, 1);
const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
if (!use_stream_k) {
if (!args.use_stream_k) {
if (args.ne01 % mmq_y == 0) {
constexpr bool need_check = false;
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>

View File

@ -174,6 +174,12 @@ class Keys:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
class IMatrix:
CHUNK_COUNT = "imatrix.chunk_count"
CHUNK_SIZE = "imatrix.chunk_size"
DATASET = "imatrix.dataset"
#
# recommended mapping of model tensor names for storage in gguf
#
@ -182,6 +188,7 @@ class Keys:
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
IMATRIX = "imatrix"
class MODEL_ARCH(IntEnum):

View File

@ -8,5 +8,6 @@
-r ./requirements/requirements-convert_hf_to_gguf.txt
-r ./requirements/requirements-convert_hf_to_gguf_update.txt
-r ./requirements/requirements-convert_legacy_imatrix_to_gguf.txt
-r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
-r ./requirements/requirements-convert_lora_to_gguf.txt

View File

@ -0,0 +1 @@
-r ./requirements-convert_legacy_llama.txt