mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-13 14:29:52 +00:00
Compare commits
11 Commits
d596105b22
...
6e850f8f6d
Author | SHA1 | Date | |
---|---|---|---|
|
6e850f8f6d | ||
|
dea5e86051 | ||
|
1329c0a75e | ||
|
61408e7fad | ||
|
b9e02e8184 | ||
|
6763f713bb | ||
|
79a2bc042d | ||
|
fc83a9e584 | ||
|
9a465199a1 | ||
|
640039106f | ||
|
d32c74d1f2 |
@ -230,7 +230,7 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
|
||||
description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file")
|
||||
parser.add_argument(
|
||||
"--outfile", type=Path,
|
||||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||
@ -257,11 +257,11 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base", type=Path, required=True,
|
||||
help="directory containing base model file",
|
||||
help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
|
||||
)
|
||||
parser.add_argument(
|
||||
"lora_path", type=Path,
|
||||
help="directory containing LoRA adapter file",
|
||||
help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
@ -840,6 +840,8 @@ class OutputFile:
|
||||
self.gguf.add_base_model_version(key, base_model_entry["version"])
|
||||
if "organization" in base_model_entry:
|
||||
self.gguf.add_base_model_organization(key, base_model_entry["organization"])
|
||||
if "description" in base_model_entry:
|
||||
self.gguf.add_base_model_description(key, base_model_entry["description"])
|
||||
if "url" in base_model_entry:
|
||||
self.gguf.add_base_model_url(key, base_model_entry["url"])
|
||||
if "doi" in base_model_entry:
|
||||
@ -849,12 +851,32 @@ class OutputFile:
|
||||
if "repo_url" in base_model_entry:
|
||||
self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
|
||||
|
||||
if metadata.datasets is not None:
|
||||
self.gguf.add_dataset_count(len(metadata.datasets))
|
||||
for key, dataset_entry in enumerate(metadata.datasets):
|
||||
if "name" in dataset_entry:
|
||||
self.gguf.add_dataset_name(key, dataset_entry["name"])
|
||||
if "author" in dataset_entry:
|
||||
self.gguf.add_dataset_author(key, dataset_entry["author"])
|
||||
if "version" in dataset_entry:
|
||||
self.gguf.add_dataset_version(key, dataset_entry["version"])
|
||||
if "organization" in dataset_entry:
|
||||
self.gguf.add_dataset_organization(key, dataset_entry["organization"])
|
||||
if "description" in dataset_entry:
|
||||
self.gguf.add_dataset_description(key, dataset_entry["description"])
|
||||
if "url" in dataset_entry:
|
||||
self.gguf.add_dataset_url(key, dataset_entry["url"])
|
||||
if "doi" in dataset_entry:
|
||||
self.gguf.add_dataset_doi(key, dataset_entry["doi"])
|
||||
if "uuid" in dataset_entry:
|
||||
self.gguf.add_dataset_uuid(key, dataset_entry["uuid"])
|
||||
if "repo_url" in dataset_entry:
|
||||
self.gguf.add_dataset_repo_url(key, dataset_entry["repo_url"])
|
||||
|
||||
if metadata.tags is not None:
|
||||
self.gguf.add_tags(metadata.tags)
|
||||
if metadata.languages is not None:
|
||||
self.gguf.add_languages(metadata.languages)
|
||||
if metadata.datasets is not None:
|
||||
self.gguf.add_datasets(metadata.datasets)
|
||||
|
||||
def add_meta_arch(self, params: Params) -> None:
|
||||
# Metadata About The Neural Architecture Itself
|
||||
|
@ -333,6 +333,15 @@ These options help improve the performance and memory usage of the LLaMA models.
|
||||
|
||||
For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-and-quantize).
|
||||
|
||||
## LoRA (Low-Rank Adaptation) adapters
|
||||
|
||||
- `--lora FNAME`: Optional path to a LoRA adapter to use with scaling of 1.0. Can be mixed with `--lora-scaled` and can be repeated to use multiple adapters.
|
||||
- `--lora-scaled FNAME`: Optional path to a LoRA adapter with user-defined scaling. Can be mixed with `--lora` and can repeated to use multiple adapters.
|
||||
|
||||
You can add LoRA adapters using `--lora` or `--lora-scaled`. For example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...` or `--lora-scaled lora_task_A.gguf 0.5 --lora-scaled lora_task_B.gguf 0.5`.
|
||||
|
||||
LoRA adapters should be in GGUF format. To convert from Hugging Face format use the `convert-lora-to-gguf.py` script. LoRA adapters are loaded separately and applied during inference - they are not merged with the main model. This means that mmap model loading is fully supported when using LoRA adapters. The old `--lora-base` flag has been removed now that merging is no longer performed.
|
||||
|
||||
## Additional Options
|
||||
|
||||
These options provide extra functionality and customization when running the LLaMA models:
|
||||
@ -341,6 +350,4 @@ These options provide extra functionality and customization when running the LLa
|
||||
- `--verbose-prompt`: Print the prompt before generating text.
|
||||
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used.
|
||||
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance.
|
||||
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
|
||||
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
|
||||
- `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.
|
||||
|
@ -11,6 +11,8 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GGML_KOMPUTE_MAX_DEVICES 16
|
||||
|
||||
struct ggml_vk_device {
|
||||
int index;
|
||||
int type; // same as VkPhysicalDeviceType
|
||||
@ -41,6 +43,8 @@ GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
|
||||
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
|
||||
|
||||
GGML_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
|
||||
kompute-shaders/op_mul_mat_q8_0.comp
|
||||
kompute-shaders/op_mul_mat_q4_0.comp
|
||||
kompute-shaders/op_mul_mat_q4_1.comp
|
||||
kompute-shaders/op_mul_mat_q4_k.comp
|
||||
kompute-shaders/op_mul_mat_q6_k.comp
|
||||
kompute-shaders/op_getrows_f32.comp
|
||||
kompute-shaders/op_getrows_f16.comp
|
||||
@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
|
||||
shaderop_mul_mat_q8_0.h
|
||||
shaderop_mul_mat_q4_0.h
|
||||
shaderop_mul_mat_q4_1.h
|
||||
shaderop_mul_mat_q4_k.h
|
||||
shaderop_mul_mat_q6_k.h
|
||||
shaderop_getrows_f32.h
|
||||
shaderop_getrows_f16.h
|
||||
|
@ -991,6 +991,73 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||
}
|
||||
}
|
||||
return;
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||
|
||||
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
|
||||
const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
|
||||
const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
|
||||
const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
|
||||
|
||||
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
||||
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
||||
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
||||
const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
|
||||
const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
|
||||
const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
|
||||
const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
|
||||
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
// vector version needs Zvfhmin extension
|
||||
const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
|
||||
const float b_scales[8] = {
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[7])
|
||||
};
|
||||
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
|
||||
sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
__riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||
{
|
||||
float sumf[8];
|
||||
@ -3171,6 +3238,207 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||
vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
||||
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
||||
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
||||
const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
|
||||
const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
|
||||
const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
|
||||
const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
|
||||
|
||||
// vector version needs Zvfhmin extension
|
||||
const float a_scales[4] = {
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[3])
|
||||
};
|
||||
const float b_scales[8] = {
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[7])
|
||||
};
|
||||
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
||||
|
||||
const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
|
||||
const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
|
||||
const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
|
||||
const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l0;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l0 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
|
||||
sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
|
||||
const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
|
||||
const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
|
||||
const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
|
||||
const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l1;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l1 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
|
||||
sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
|
||||
const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
|
||||
const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
|
||||
const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
|
||||
const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l2;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l2 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
|
||||
sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
|
||||
const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
|
||||
const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
|
||||
const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
|
||||
const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
|
||||
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||
vint16m4_t sumi_l3;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
|
||||
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||
|
||||
sumi_l3 = sumi_hi_m;
|
||||
}
|
||||
|
||||
{
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
|
||||
sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
|
||||
}
|
||||
}
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||
|
@ -562,6 +562,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
|
||||
#include "ggml-cann.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_KOMPUTE
|
||||
#include "ggml-kompute.h"
|
||||
#endif
|
||||
|
||||
struct ggml_backend_registry {
|
||||
std::vector<ggml_backend_reg_t> backends;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
@ -591,8 +595,9 @@ struct ggml_backend_registry {
|
||||
#ifdef GGML_USE_AMX
|
||||
register_backend(ggml_backend_amx_reg());
|
||||
#endif
|
||||
|
||||
// TODO: kompute
|
||||
#ifdef GGML_USE_KOMPUTE
|
||||
register_backend(ggml_backend_kompute_reg());
|
||||
#endif
|
||||
|
||||
register_backend(ggml_backend_cpu_reg());
|
||||
}
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "shaderop_mul_mat_q8_0.h"
|
||||
#include "shaderop_mul_mat_q4_0.h"
|
||||
#include "shaderop_mul_mat_q4_1.h"
|
||||
#include "shaderop_mul_mat_q4_k.h"
|
||||
#include "shaderop_mul_mat_q6_k.h"
|
||||
#include "shaderop_mul_mat_mat_f32.h"
|
||||
#include "shaderop_getrows_f32.h"
|
||||
@ -42,6 +43,7 @@
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
@ -273,18 +275,9 @@ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t mem
|
||||
return results;
|
||||
}
|
||||
|
||||
// public API returns a C-style array
|
||||
ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
|
||||
auto devices = ggml_vk_available_devices_internal(memoryRequired);
|
||||
*count = devices.size();
|
||||
if (devices.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
|
||||
auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
|
||||
memcpy(arr, devices.data(), nbytes);
|
||||
return arr;
|
||||
static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
|
||||
static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
|
||||
return devices;
|
||||
}
|
||||
|
||||
static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
|
||||
@ -341,7 +334,7 @@ ggml_vk_device ggml_vk_current_device() {
|
||||
if (!komputeManager()->hasDevice())
|
||||
return ggml_vk_device();
|
||||
|
||||
auto devices = ggml_vk_available_devices_internal(0);
|
||||
auto devices = ggml_vk_available_devices();
|
||||
ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
|
||||
GGML_ASSERT(!devices.empty());
|
||||
return devices.front();
|
||||
@ -1075,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
|
||||
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q4_k(
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
|
||||
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
|
||||
int32_t ne1, int32_t r2, int32_t r3
|
||||
) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
|
||||
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
|
||||
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
|
||||
} pushConsts {
|
||||
0, 0, 0,
|
||||
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
|
||||
};
|
||||
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
s_algo->setTensors({inA, inB, out});
|
||||
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
s_algo->updateDescriptors(s_kompute_context->pool.get());
|
||||
}
|
||||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q6_k(
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
@ -1323,17 +1350,7 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
|
||||
ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
@ -1402,6 +1419,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return true;
|
||||
default:
|
||||
;
|
||||
@ -1410,6 +1428,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
|
||||
;
|
||||
}
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
|
||||
@ -1458,11 +1478,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
|
||||
any_commands_recorded = true;
|
||||
|
||||
if (!ggml_vk_supports_op(dst)) {
|
||||
fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
||||
GGML_ABORT("unsupported op");
|
||||
}
|
||||
|
||||
const int32_t ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int32_t ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int32_t ne02 = src0 ? src0->ne[2] : 0;
|
||||
@ -1656,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
ggml_vk_mul_mat_q4_k(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
||||
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
ggml_vk_mul_mat_q6_k(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
||||
@ -1907,25 +1928,31 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
|
||||
};
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
|
||||
static std::vector<ggml_backend_buffer_type> bufts = []() {
|
||||
std::vector<ggml_backend_buffer_type> vec;
|
||||
auto devices = ggml_vk_available_devices_internal(0);
|
||||
vec.reserve(devices.size());
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
for (const auto & dev : devices) {
|
||||
vec.push_back({
|
||||
/* .iface = */ ggml_backend_kompute_buffer_type_interface,
|
||||
/* .device = */ nullptr,
|
||||
/* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
|
||||
});
|
||||
auto devices = ggml_vk_available_devices();
|
||||
int32_t device_count = (int32_t) devices.size();
|
||||
GGML_ASSERT(device < device_count);
|
||||
GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
|
||||
|
||||
static ggml_backend_buffer_type
|
||||
ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
|
||||
|
||||
static bool ggml_backend_kompute_buffer_type_initialized = false;
|
||||
|
||||
if (!ggml_backend_kompute_buffer_type_initialized) {
|
||||
for (int32_t i = 0; i < device_count; i++) {
|
||||
ggml_backend_kompute_buffer_types[i] = {
|
||||
/* .iface = */ ggml_backend_kompute_buffer_type_interface,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
|
||||
/* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
|
||||
};
|
||||
}
|
||||
return vec;
|
||||
}();
|
||||
ggml_backend_kompute_buffer_type_initialized = true;
|
||||
}
|
||||
|
||||
auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
|
||||
return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
|
||||
});
|
||||
return it < bufts.end() ? &*it : nullptr;
|
||||
return &ggml_backend_kompute_buffer_types[device];
|
||||
}
|
||||
|
||||
// backend
|
||||
@ -1953,16 +1980,6 @@ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, st
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||
GGML_UNUSED(backend);
|
||||
return ggml_vk_supports_op(op);
|
||||
}
|
||||
|
||||
static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
||||
GGML_UNUSED(backend);
|
||||
return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
|
||||
}
|
||||
|
||||
static struct ggml_backend_i kompute_backend_i = {
|
||||
/* .get_name = */ ggml_backend_kompute_name,
|
||||
/* .free = */ ggml_backend_kompute_free,
|
||||
@ -1991,7 +2008,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
|
||||
ggml_backend_t kompute_backend = new ggml_backend {
|
||||
/* .guid = */ ggml_backend_kompute_guid(),
|
||||
/* .interface = */ kompute_backend_i,
|
||||
/* .device = */ nullptr,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
|
||||
/* .context = */ s_kompute_context,
|
||||
};
|
||||
|
||||
@ -2001,3 +2018,167 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
|
||||
bool ggml_backend_is_kompute(ggml_backend_t backend) {
|
||||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
|
||||
}
|
||||
|
||||
static size_t ggml_backend_kompute_get_device_count() {
|
||||
auto devices = ggml_vk_available_devices();
|
||||
return devices.size();
|
||||
}
|
||||
|
||||
static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
|
||||
auto devices = ggml_vk_available_devices();
|
||||
GGML_ASSERT((size_t) device < devices.size());
|
||||
snprintf(description, description_size, "%s", devices[device].name);
|
||||
}
|
||||
|
||||
static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
|
||||
auto devices = ggml_vk_available_devices();
|
||||
GGML_ASSERT((size_t) device < devices.size());
|
||||
*total = devices[device].heapSize;
|
||||
*free = devices[device].heapSize;
|
||||
}
|
||||
|
||||
//////////////////////////
|
||||
|
||||
struct ggml_backend_kompute_device_context {
|
||||
int device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
|
||||
ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
|
||||
return ctx->name.c_str();
|
||||
}
|
||||
|
||||
static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
|
||||
ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
|
||||
ggml_backend_kompute_get_device_memory(ctx->device, free, total);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
|
||||
return ggml_backend_kompute_buffer_type(ctx->device);
|
||||
}
|
||||
|
||||
static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
|
||||
ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
|
||||
|
||||
return buft_ctx->device == ctx->device;
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||
}
|
||||
|
||||
static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_kompute_device_get_name(dev);
|
||||
props->description = ggml_backend_kompute_device_get_description(dev);
|
||||
props->type = ggml_backend_kompute_device_get_type(dev);
|
||||
ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
props->caps = {
|
||||
/* async = */ false,
|
||||
/* host_buffer = */ false,
|
||||
/* .buffer_from_host_ptr = */ false,
|
||||
/* events = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
GGML_UNUSED(params);
|
||||
ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
|
||||
return ggml_backend_kompute_init(ctx->device);
|
||||
}
|
||||
|
||||
static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
const int min_batch_size = 32;
|
||||
|
||||
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
|
||||
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
|
||||
/* .get_name = */ ggml_backend_kompute_device_get_name,
|
||||
/* .get_description = */ ggml_backend_kompute_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_kompute_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_kompute_device_get_type,
|
||||
/* .get_props = */ ggml_backend_kompute_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_kompute_device_init,
|
||||
/* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ NULL,
|
||||
/* .supports_op = */ ggml_backend_kompute_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_kompute_device_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_kompute_device_offload_op,
|
||||
/* .event_new = */ NULL,
|
||||
/* .event_free = */ NULL,
|
||||
/* .event_synchronize = */ NULL,
|
||||
};
|
||||
|
||||
static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
|
||||
GGML_UNUSED(reg);
|
||||
return "Kompute";
|
||||
}
|
||||
|
||||
static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
GGML_UNUSED(reg);
|
||||
return ggml_backend_kompute_get_device_count();
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
|
||||
static std::vector<ggml_backend_dev_t> devices;
|
||||
|
||||
static bool initialized = false;
|
||||
|
||||
{
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (!initialized) {
|
||||
for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
|
||||
ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
|
||||
char desc[256];
|
||||
ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
|
||||
ctx->device = i;
|
||||
ctx->name = "Kompute" + std::to_string(i);
|
||||
ctx->description = desc;
|
||||
devices.push_back(new ggml_backend_device {
|
||||
/* .iface = */ ggml_backend_kompute_device_i,
|
||||
/* .reg = */ reg,
|
||||
/* .context = */ ctx,
|
||||
});
|
||||
}
|
||||
initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(device < devices.size());
|
||||
return devices[device];
|
||||
}
|
||||
|
||||
static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
|
||||
/* .get_name = */ ggml_backend_kompute_reg_get_name,
|
||||
/* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
|
||||
/* .get_device = */ ggml_backend_kompute_reg_get_device,
|
||||
/* .get_proc_address = */ NULL,
|
||||
};
|
||||
|
||||
ggml_backend_reg_t ggml_backend_kompute_reg() {
|
||||
static ggml_backend_reg reg = {
|
||||
/* .iface = */ ggml_backend_kompute_reg_i,
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
return ®
|
||||
}
|
||||
|
112
ggml/src/ggml.c
112
ggml/src/ggml.c
@ -22102,18 +22102,46 @@ static size_t gguf_type_size(enum gguf_type type) {
|
||||
return GGUF_TYPE_SIZE[type];
|
||||
}
|
||||
|
||||
static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
|
||||
GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS);
|
||||
GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT);
|
||||
static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
|
||||
if (info->n_dims > GGML_MAX_DIMS) {
|
||||
fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (strlen(info->name.data) >= GGML_MAX_NAME) {
|
||||
fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < info->n_dims; ++i) {
|
||||
GGML_ASSERT(info->ne[i] > 0);
|
||||
if (info->ne[i] <= 0) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// prevent overflow for total number of elements
|
||||
GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]);
|
||||
GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]);
|
||||
GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]);
|
||||
if (INT64_MAX/info->ne[1] <= info->ne[0]) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
|
||||
@ -22136,7 +22164,11 @@ static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
p->data = GGML_CALLOC(p->n + 1, 1);
|
||||
p->data = calloc(p->n + 1, 1);
|
||||
if (!p->data) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
|
||||
return false;
|
||||
}
|
||||
|
||||
ok = ok && gguf_fread_el(file, p->data, p->n, offset);
|
||||
|
||||
@ -22170,7 +22202,11 @@ static void gguf_free_kv(struct gguf_kv * kv) {
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_empty(void) {
|
||||
struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
|
||||
struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
|
||||
if (!ctx) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
|
||||
ctx->header.version = GGUF_VERSION;
|
||||
@ -22216,7 +22252,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
|
||||
bool ok = true;
|
||||
|
||||
struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
|
||||
struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
|
||||
if (!ctx) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
|
||||
fclose(file);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// read the header
|
||||
{
|
||||
@ -22255,9 +22296,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
{
|
||||
const uint64_t n_kv = ctx->header.n_kv;
|
||||
|
||||
// header.n_kv will hold the actual value of pairs that were successfully read in the loop below
|
||||
ctx->header.n_kv = 0;
|
||||
ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
|
||||
ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
|
||||
if (!ctx->kv) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
|
||||
fclose(file);
|
||||
gguf_free(ctx);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
for (uint64_t i = 0; i < n_kv; ++i) {
|
||||
struct gguf_kv * kv = &ctx->kv[i];
|
||||
@ -22308,7 +22353,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
return NULL;
|
||||
}
|
||||
|
||||
kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
|
||||
kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
|
||||
if (!kv->value.arr.data) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
|
||||
fclose(file);
|
||||
gguf_free(ctx);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
|
||||
} break;
|
||||
@ -22322,24 +22373,36 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
return NULL;
|
||||
}
|
||||
|
||||
kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));
|
||||
kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
|
||||
if (!kv->value.arr.data) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
|
||||
fclose(file);
|
||||
gguf_free(ctx);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
|
||||
ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
|
||||
}
|
||||
} break;
|
||||
case GGUF_TYPE_ARRAY:
|
||||
default: GGML_ABORT("invalid type");
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
|
||||
ok = false;
|
||||
} break;
|
||||
}
|
||||
} break;
|
||||
default: GGML_ABORT("invalid type");
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
|
||||
ok = false;
|
||||
} break;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
break;
|
||||
}
|
||||
|
||||
ctx->header.n_kv++;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
@ -22352,7 +22415,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
|
||||
// read the tensor infos
|
||||
if (ctx->header.n_tensors > 0) {
|
||||
ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
|
||||
ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
|
||||
if (!ctx->infos) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
|
||||
fclose(file);
|
||||
gguf_free(ctx);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
|
||||
struct gguf_tensor_info * info = &ctx->infos[i];
|
||||
@ -22373,8 +22442,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
|
||||
ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
|
||||
|
||||
// TODO: return an error instead of crashing with GGML_ASSERT
|
||||
gguf_tensor_info_sanitize(info);
|
||||
ok = ok && gguf_tensor_info_sanitize(info);
|
||||
|
||||
// make sure there is no duplicated tensor names
|
||||
for (uint64_t j = 0; j < i && ok; ++j) {
|
||||
|
@ -15,6 +15,7 @@
|
||||
#define TWOPI_F 6.283185307179586f
|
||||
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
|
||||
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
|
||||
@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
|
||||
return reg;
|
||||
}
|
||||
|
||||
#define sizeof_block_q4_k 144
|
||||
struct block_q4_k {
|
||||
float16_t d;
|
||||
float16_t dmin;
|
||||
uint8_t scales[K_SCALE_SIZE];
|
||||
uint8_t qs[QK_K/2];
|
||||
};
|
||||
|
||||
#define sizeof_block_q6_k 210
|
||||
struct block_q6_k {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
|
133
ggml/src/kompute-shaders/op_mul_mat_q4_k.comp
Normal file
133
ggml/src/kompute-shaders/op_mul_mat_q4_k.comp
Normal file
@ -0,0 +1,133 @@
|
||||
#version 450
|
||||
|
||||
#include "common.comp"
|
||||
|
||||
#define N_DST 4
|
||||
#define SIZE_OF_BLOCK sizeof_block_q4_k
|
||||
|
||||
layout(local_size_x = 4) in;
|
||||
layout(local_size_y = 8) in;
|
||||
layout(local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
|
||||
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
|
||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint inAOff;
|
||||
uint inBOff;
|
||||
uint outOff;
|
||||
int ne00;
|
||||
int ne10;
|
||||
int ne0;
|
||||
int ne1;
|
||||
int ne01;
|
||||
int ne02;
|
||||
int ne12;
|
||||
int r2;
|
||||
int r3;
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint16_t kmask1 = uint16_t(0x3f3f);
|
||||
const uint16_t kmask2 = uint16_t(0x0f0f);
|
||||
const uint16_t kmask3 = uint16_t(0xc0c0);
|
||||
|
||||
const uint ix = gl_SubgroupInvocationID/8; // 0...3
|
||||
const uint it = gl_SubgroupInvocationID%8; // 0...7
|
||||
const uint iq = it/4; // 0 or 1
|
||||
const uint ir = it%4; // 0...3
|
||||
|
||||
const uint nb = pcs.ne00/QK_K;
|
||||
|
||||
const uint r0 = gl_WorkGroupID.x;
|
||||
const uint r1 = gl_WorkGroupID.y;
|
||||
const uint im = gl_WorkGroupID.z;
|
||||
|
||||
const uint first_row = r0 * N_DST;
|
||||
const uint ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%pcs.ne12;
|
||||
const uint i13 = im/pcs.ne12;
|
||||
|
||||
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
|
||||
|
||||
const uint xblk = ib_row + offset0 + pcs.inAOff;
|
||||
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
|
||||
float all_sum = 0.f;
|
||||
|
||||
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
for (uint ib = ix; ib < nb; ib += 4) {
|
||||
const uint blk_idx = ib + xblk;
|
||||
|
||||
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
|
||||
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
|
||||
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
|
||||
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
uint row_idx = row * nb;
|
||||
|
||||
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
|
||||
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
|
||||
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
|
||||
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
|
||||
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
|
||||
|
||||
uint16_t sc16[4];
|
||||
sc16[0] = sc_0 & kmask1;
|
||||
sc16[1] = sc_2 & kmask1;
|
||||
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
|
||||
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
|
||||
|
||||
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
|
||||
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
|
||||
acc1[0] += yl[i+0] * (q1 & 0x000F);
|
||||
acc1[1] += yl[i+1] * (q1 & 0x0F00);
|
||||
acc1[2] += yl[i+8] * (q1 & 0x00F0);
|
||||
acc1[3] += yl[i+9] * (q1 & 0xF000);
|
||||
acc2[0] += yh[i+0] * (q2 & 0x000F);
|
||||
acc2[1] += yh[i+1] * (q2 & 0x0F00);
|
||||
acc2[2] += yh[i+8] * (q2 & 0x00F0);
|
||||
acc2[3] += yh[i+9] * (q2 & 0xF000);
|
||||
}
|
||||
|
||||
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
|
||||
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
|
||||
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
|
||||
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
|
||||
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
|
||||
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
|
||||
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
|
||||
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
|
||||
|
||||
float dall = float(inA[blk_idx + row_idx].d);
|
||||
float dmin = float(inA[blk_idx + row_idx].dmin);
|
||||
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
|
||||
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
|
||||
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
|
||||
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
|
||||
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
|
||||
}
|
||||
|
||||
y4 += 4 * QK_K;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = subgroupAdd(sumf[row]);
|
||||
if (subgroupElect()) {
|
||||
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
@ -64,15 +64,27 @@ class Keys:
|
||||
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
|
||||
BASE_MODEL_VERSION = "general.base_model.{id}.version"
|
||||
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
|
||||
BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description"
|
||||
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
|
||||
BASE_MODEL_DOI = "general.base_model.{id}.doi"
|
||||
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
|
||||
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
|
||||
# Dataset Source
|
||||
DATASET_COUNT = "general.dataset.count"
|
||||
DATASET_NAME = "general.dataset.{id}.name"
|
||||
DATASET_AUTHOR = "general.dataset.{id}.author"
|
||||
DATASET_VERSION = "general.dataset.{id}.version"
|
||||
DATASET_ORGANIZATION = "general.dataset.{id}.organization"
|
||||
DATASET_DESCRIPTION = "general.dataset.{id}.description"
|
||||
DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper
|
||||
DATASET_DOI = "general.dataset.{id}.doi"
|
||||
DATASET_UUID = "general.dataset.{id}.uuid"
|
||||
DATASET_REPO_URL = "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
|
||||
# Array based KV stores
|
||||
TAGS = "general.tags"
|
||||
LANGUAGES = "general.languages"
|
||||
DATASETS = "general.datasets"
|
||||
|
||||
class LLM:
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
|
@ -568,6 +568,9 @@ class GGUFWriter:
|
||||
def add_base_model_organization(self, source_id: int, organization: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
|
||||
|
||||
def add_base_model_description(self, source_id: int, description: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description)
|
||||
|
||||
def add_base_model_url(self, source_id: int, url: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
|
||||
|
||||
@ -580,15 +583,42 @@ class GGUFWriter:
|
||||
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
|
||||
|
||||
def add_dataset_count(self, source_count: int) -> None:
|
||||
self.add_uint32(Keys.General.DATASET_COUNT, source_count)
|
||||
|
||||
def add_dataset_name(self, source_id: int, name: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)
|
||||
|
||||
def add_dataset_author(self, source_id: int, author: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)
|
||||
|
||||
def add_dataset_version(self, source_id: int, version: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)
|
||||
|
||||
def add_dataset_organization(self, source_id: int, organization: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)
|
||||
|
||||
def add_dataset_description(self, source_id: int, description: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)
|
||||
|
||||
def add_dataset_url(self, source_id: int, url: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)
|
||||
|
||||
def add_dataset_doi(self, source_id: int, doi: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)
|
||||
|
||||
def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)
|
||||
|
||||
def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)
|
||||
|
||||
def add_tags(self, tags: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.TAGS, tags)
|
||||
|
||||
def add_languages(self, languages: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.LANGUAGES, languages)
|
||||
|
||||
def add_datasets(self, datasets: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.DATASETS, datasets)
|
||||
|
||||
def add_tensor_data_layout(self, layout: str) -> None:
|
||||
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
|
||||
|
||||
|
@ -41,7 +41,7 @@ class Metadata:
|
||||
base_models: Optional[list[dict]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
languages: Optional[list[str]] = None
|
||||
datasets: Optional[list[str]] = None
|
||||
datasets: Optional[list[dict]] = None
|
||||
|
||||
@staticmethod
|
||||
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
|
||||
@ -91,9 +91,11 @@ class Metadata:
|
||||
# Base Models is received here as an array of models
|
||||
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
|
||||
|
||||
# Datasets is received here as an array of datasets
|
||||
metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
|
||||
|
||||
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
|
||||
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
|
||||
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
|
||||
|
||||
# Direct Metadata Override (via direct cli argument)
|
||||
if model_name is not None:
|
||||
@ -346,12 +348,12 @@ class Metadata:
|
||||
use_model_card_metadata("author", "model_creator")
|
||||
use_model_card_metadata("basename", "model_type")
|
||||
|
||||
if "base_model" in model_card:
|
||||
if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
|
||||
# This represents the parent models that this is based on
|
||||
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
||||
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
|
||||
metadata_base_models = []
|
||||
base_model_value = model_card.get("base_model", None)
|
||||
base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))
|
||||
|
||||
if base_model_value is not None:
|
||||
if isinstance(base_model_value, str):
|
||||
@ -364,18 +366,106 @@ class Metadata:
|
||||
|
||||
for model_id in metadata_base_models:
|
||||
# NOTE: model size of base model is assumed to be similar to the size of the current model
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
base_model = {}
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(model_full_name_component)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
if org_component is not None and model_full_name_component is not None:
|
||||
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
if isinstance(model_id, str):
|
||||
if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
|
||||
base_model["repo_url"] = model_id
|
||||
|
||||
# Check if Hugging Face ID is present in URL
|
||||
if "huggingface.co" in model_id:
|
||||
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
|
||||
if match:
|
||||
model_id_component = match.group(1)
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
|
||||
|
||||
# Populate model dictionary with extracted components
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(model_full_name_component)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
|
||||
else:
|
||||
# Likely a Hugging Face ID
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
|
||||
# Populate model dictionary with extracted components
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(model_full_name_component)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
if org_component is not None and model_full_name_component is not None:
|
||||
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
|
||||
elif isinstance(model_id, dict):
|
||||
base_model = model_id
|
||||
|
||||
else:
|
||||
logger.error(f"base model entry '{str(model_id)}' not in a known format")
|
||||
|
||||
metadata.base_models.append(base_model)
|
||||
|
||||
if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
|
||||
# This represents the datasets that this was trained from
|
||||
metadata_datasets = []
|
||||
dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))
|
||||
|
||||
if dataset_value is not None:
|
||||
if isinstance(dataset_value, str):
|
||||
metadata_datasets.append(dataset_value)
|
||||
elif isinstance(dataset_value, list):
|
||||
metadata_datasets.extend(dataset_value)
|
||||
|
||||
if metadata.datasets is None:
|
||||
metadata.datasets = []
|
||||
|
||||
for dataset_id in metadata_datasets:
|
||||
# NOTE: model size of base model is assumed to be similar to the size of the current model
|
||||
dataset = {}
|
||||
if isinstance(dataset_id, str):
|
||||
if dataset_id.startswith(("http://", "https://", "ssh://")):
|
||||
dataset["repo_url"] = dataset_id
|
||||
|
||||
# Check if Hugging Face ID is present in URL
|
||||
if "huggingface.co" in dataset_id:
|
||||
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
|
||||
if match:
|
||||
dataset_id_component = match.group(1)
|
||||
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
|
||||
|
||||
# Populate dataset dictionary with extracted components
|
||||
if dataset_name_component is not None:
|
||||
dataset["name"] = Metadata.id_to_title(dataset_name_component)
|
||||
if org_component is not None:
|
||||
dataset["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
dataset["version"] = version
|
||||
|
||||
else:
|
||||
# Likely a Hugging Face ID
|
||||
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
|
||||
|
||||
# Populate dataset dictionary with extracted components
|
||||
if dataset_name_component is not None:
|
||||
dataset["name"] = Metadata.id_to_title(dataset_name_component)
|
||||
if org_component is not None:
|
||||
dataset["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
dataset["version"] = version
|
||||
if org_component is not None and dataset_name_component is not None:
|
||||
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
|
||||
|
||||
elif isinstance(dataset_id, dict):
|
||||
dataset = dataset_id
|
||||
|
||||
else:
|
||||
logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
|
||||
|
||||
metadata.datasets.append(dataset)
|
||||
|
||||
use_model_card_metadata("license", "license")
|
||||
use_model_card_metadata("license_name", "license_name")
|
||||
use_model_card_metadata("license_link", "license_link")
|
||||
@ -386,9 +476,6 @@ class Metadata:
|
||||
use_array_model_card_metadata("languages", "languages")
|
||||
use_array_model_card_metadata("languages", "language")
|
||||
|
||||
use_array_model_card_metadata("datasets", "datasets")
|
||||
use_array_model_card_metadata("datasets", "dataset")
|
||||
|
||||
# Hugging Face Parameter Heuristics
|
||||
####################################
|
||||
|
||||
@ -493,6 +580,8 @@ class Metadata:
|
||||
gguf_writer.add_base_model_version(key, base_model_entry["version"])
|
||||
if "organization" in base_model_entry:
|
||||
gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
|
||||
if "description" in base_model_entry:
|
||||
gguf_writer.add_base_model_description(key, base_model_entry["description"])
|
||||
if "url" in base_model_entry:
|
||||
gguf_writer.add_base_model_url(key, base_model_entry["url"])
|
||||
if "doi" in base_model_entry:
|
||||
@ -502,9 +591,29 @@ class Metadata:
|
||||
if "repo_url" in base_model_entry:
|
||||
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
|
||||
|
||||
if self.datasets is not None:
|
||||
gguf_writer.add_dataset_count(len(self.datasets))
|
||||
for key, dataset_entry in enumerate(self.datasets):
|
||||
if "name" in dataset_entry:
|
||||
gguf_writer.add_dataset_name(key, dataset_entry["name"])
|
||||
if "author" in dataset_entry:
|
||||
gguf_writer.add_dataset_author(key, dataset_entry["author"])
|
||||
if "version" in dataset_entry:
|
||||
gguf_writer.add_dataset_version(key, dataset_entry["version"])
|
||||
if "organization" in dataset_entry:
|
||||
gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
|
||||
if "description" in dataset_entry:
|
||||
gguf_writer.add_dataset_description(key, dataset_entry["description"])
|
||||
if "url" in dataset_entry:
|
||||
gguf_writer.add_dataset_url(key, dataset_entry["url"])
|
||||
if "doi" in dataset_entry:
|
||||
gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
|
||||
if "uuid" in dataset_entry:
|
||||
gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
|
||||
if "repo_url" in dataset_entry:
|
||||
gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
|
||||
|
||||
if self.tags is not None:
|
||||
gguf_writer.add_tags(self.tags)
|
||||
if self.languages is not None:
|
||||
gguf_writer.add_languages(self.languages)
|
||||
if self.datasets is not None:
|
||||
gguf_writer.add_datasets(self.datasets)
|
||||
|
@ -182,8 +182,43 @@ class TestMetadataMethod(unittest.TestCase):
|
||||
expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
|
||||
expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
|
||||
expect.languages=['en']
|
||||
expect.datasets=['teknium/OpenHermes-2.5']
|
||||
expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Base Model spec is inferred from model id
|
||||
model_card = {'base_models': 'teknium/OpenHermes-2.5'}
|
||||
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Base Model spec is only url
|
||||
model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
|
||||
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Base Model spec is given directly
|
||||
model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
|
||||
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Dataset spec is inferred from model id
|
||||
model_card = {'datasets': 'teknium/OpenHermes-2.5'}
|
||||
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Dataset spec is only url
|
||||
model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
|
||||
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
# Dataset spec is given directly
|
||||
model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
|
||||
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
def test_apply_metadata_heuristic_from_hf_parameters(self):
|
||||
|
@ -4273,8 +4273,11 @@ struct llama_model_loader {
|
||||
|
||||
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
|
||||
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
|
||||
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
||||
if (tensor_idx < 0) {
|
||||
throw std::runtime_error(format("tensor '%s' not found in the model", name));
|
||||
}
|
||||
|
||||
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
||||
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
|
||||
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
|
||||
}
|
||||
@ -7426,7 +7429,7 @@ static bool llm_load_tensors(
|
||||
if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
|
||||
return nullptr;
|
||||
}
|
||||
throw std::runtime_error(format("missing tensor %s", tn.str().c_str()));
|
||||
throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
|
||||
}
|
||||
|
||||
// some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
|
||||
|
Loading…
Reference in New Issue
Block a user