minor style changes

This commit is contained in:
slaren 2024-10-31 19:23:44 +01:00
commit 13eba91a32
7 changed files with 242 additions and 28 deletions

View File

@ -3259,7 +3259,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.terminate();
};
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
ctx_server.queue_tasks.start_loop();

View File

@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
kompute-shaders/op_mul_mat_q8_0.comp
kompute-shaders/op_mul_mat_q4_0.comp
kompute-shaders/op_mul_mat_q4_1.comp
kompute-shaders/op_mul_mat_q4_k.comp
kompute-shaders/op_mul_mat_q6_k.comp
kompute-shaders/op_getrows_f32.comp
kompute-shaders/op_getrows_f16.comp
@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
shaderop_mul_mat_q8_0.h
shaderop_mul_mat_q4_0.h
shaderop_mul_mat_q4_1.h
shaderop_mul_mat_q4_k.h
shaderop_mul_mat_q6_k.h
shaderop_getrows_f32.h
shaderop_getrows_f16.h

View File

@ -20,6 +20,7 @@
#include "shaderop_mul_mat_q8_0.h"
#include "shaderop_mul_mat_q4_0.h"
#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q4_k.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
#include "shaderop_getrows_f32.h"
@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}
static void ggml_vk_mul_mat_q4_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
int32_t ne1, int32_t r2, int32_t r3
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
} pushConsts {
0, 0, 0,
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
};
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__)) {
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(__func__);
s_algo->setTensors({inA, inB, out});
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
s_algo->setPushConstants<PushConstants>({pushConsts});
s_algo->updateDescriptors(s_kompute_context->pool.get());
}
seq.record<kp::OpAlgoDispatch>(s_algo);
}
static void ggml_vk_mul_mat_q6_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_K:
return true;
default:
;
@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
);
break;
case GGML_TYPE_Q4_K:
ggml_vk_mul_mat_q4_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
);
break;
case GGML_TYPE_Q6_K:
ggml_vk_mul_mat_q6_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,

View File

@ -22102,18 +22102,46 @@ static size_t gguf_type_size(enum gguf_type type) {
return GGUF_TYPE_SIZE[type];
}
static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS);
GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT);
static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
if (info->n_dims > GGML_MAX_DIMS) {
fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
return false;
}
if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
return false;
}
if (strlen(info->name.data) >= GGML_MAX_NAME) {
fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
return false;
}
for (uint32_t i = 0; i < info->n_dims; ++i) {
GGML_ASSERT(info->ne[i] > 0);
if (info->ne[i] <= 0) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
return false;
}
}
// prevent overflow for total number of elements
GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]);
GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]);
GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]);
if (INT64_MAX/info->ne[1] <= info->ne[0]) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
return false;
}
if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
return false;
}
if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
return false;
}
return true;
}
static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
@ -22414,8 +22442,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
// TODO: return an error instead of crashing with GGML_ASSERT
gguf_tensor_info_sanitize(info);
ok = ok && gguf_tensor_info_sanitize(info);
// make sure there is no duplicated tensor names
for (uint64_t j = 0; j < i && ok; ++j) {

View File

@ -15,6 +15,7 @@
#define TWOPI_F 6.283185307179586f
#define QK_K 256
#define K_SCALE_SIZE 12
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
return reg;
}
#define sizeof_block_q4_k 144
struct block_q4_k {
float16_t d;
float16_t dmin;
uint8_t scales[K_SCALE_SIZE];
uint8_t qs[QK_K/2];
};
#define sizeof_block_q6_k 210
struct block_q6_k {
uint8_t ql[QK_K/2]; // quants, lower 4 bits

View File

@ -0,0 +1,133 @@
#version 450
#include "common.comp"
#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k
layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne10;
int ne0;
int ne1;
int ne01;
int ne02;
int ne12;
int r2;
int r3;
} pcs;
void main() {
const uint16_t kmask1 = uint16_t(0x3f3f);
const uint16_t kmask2 = uint16_t(0x0f0f);
const uint16_t kmask3 = uint16_t(0xc0c0);
const uint ix = gl_SubgroupInvocationID/8; // 0...3
const uint it = gl_SubgroupInvocationID%8; // 0...7
const uint iq = it/4; // 0 or 1
const uint ir = it%4; // 0...3
const uint nb = pcs.ne00/QK_K;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = r0 * N_DST;
const uint ib_row = first_row * nb;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
const uint xblk = ib_row + offset0 + pcs.inAOff;
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
float all_sum = 0.f;
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
for (uint ib = ix; ib < nb; ib += 4) {
const uint blk_idx = ib + xblk;
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
}
for (int row = 0; row < N_DST; row++) {
uint row_idx = row * nb;
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
uint16_t sc16[4];
sc16[0] = sc_0 & kmask1;
sc16[1] = sc_2 & kmask1;
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
acc1[0] += yl[i+0] * (q1 & 0x000F);
acc1[1] += yl[i+1] * (q1 & 0x0F00);
acc1[2] += yl[i+8] * (q1 & 0x00F0);
acc1[3] += yl[i+9] * (q1 & 0xF000);
acc2[0] += yh[i+0] * (q2 & 0x000F);
acc2[1] += yh[i+1] * (q2 & 0x0F00);
acc2[2] += yh[i+8] * (q2 & 0x00F0);
acc2[3] += yh[i+9] * (q2 & 0xF000);
}
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
float dall = float(inA[blk_idx + row_idx].d);
float dmin = float(inA[blk_idx + row_idx].dmin);
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
}
y4 += 4 * QK_K;
}
for (int row = 0; row < N_DST; ++row) {
all_sum = subgroupAdd(sumf[row]);
if (subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
}
}
}

View File

@ -4273,8 +4273,11 @@ struct llama_model_loader {
llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor));
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
if (tensor_idx < 0) {
throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor)));
}
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor)));
}
@ -4412,8 +4415,8 @@ struct llama_model_loader {
uint32_t n_type_max = 0;
enum ggml_type type_max = GGML_TYPE_F32;
for (auto it = weights_map.begin(); it != weights_map.end(); it++) {
const llama_tensor_weight & w = it->second;
for (const auto & it : weights_map) {
const llama_tensor_weight & w = it.second;
const ggml_tensor * tensor = w.tensor;
enum ggml_type type = tensor->type;
@ -4692,9 +4695,7 @@ struct llama_model_loader {
}
const llama_tensor_weight * get_weight(const char * name) const {
std::string tensor_name(name);
auto pos = weights_map.find(tensor_name);
auto pos = weights_map.find(name);
if (pos != weights_map.end()) {
return &pos->second;
}
@ -4832,8 +4833,8 @@ struct llama_model_loader {
}
// compute the total size of all tensors for progress reporting
for (auto it = weights_map.begin(); it != weights_map.end(); it++) {
size_data += ggml_nbytes(it->second.tensor);
for (const auto & it : weights_map) {
size_data += ggml_nbytes(it.second.tensor);
}
}
@ -7419,7 +7420,7 @@ static bool llm_load_tensors(
if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
return nullptr;
}
throw std::runtime_error(format("missing tensor %s", tn.str().c_str()));
throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
}
// some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
@ -18588,8 +18589,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
}
}
for (auto it = ml.weights_map.begin(); it != ml.weights_map.end(); ++it) {
const struct ggml_tensor * tensor = it->second.tensor;
for (const auto & it : ml.weights_map) {
const struct ggml_tensor * tensor = it.second.tensor;
const std::string name = ggml_get_name(tensor);
@ -18633,8 +18634,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// Assume split index is continuous
if (params->keep_split) {
for (auto it = weights_map.begin(); it != weights_map.end(); ++it) {
n_split = std::max(uint16_t(it->second.idx+1), n_split);
for (const auto & it : weights_map) {
n_split = std::max(uint16_t(it.second.idx + 1), n_split);
}
}
@ -18642,9 +18643,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
ctx_outs[0] = ctx_out;
// populate the original tensors so we get an initial meta data
for (auto it = weights_map.begin(); it != weights_map.end(); ++it) {
uint16_t i_split = params->keep_split ? it->second.idx : 0;
struct ggml_tensor * tensor = it->second.tensor;
for (const auto & it : weights_map) {
uint16_t i_split = params->keep_split ? it.second.idx : 0;
struct ggml_tensor * tensor = it.second.tensor;
if (ctx_outs[i_split] == NULL) {
ctx_outs[i_split] = gguf_init_empty();
}
@ -18691,8 +18692,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
const auto tn = LLM_TN(model.arch);
new_ofstream(0);
for (auto it = weights_map.begin(); it != weights_map.end(); ++it) {
auto weight = it->second;
for (const auto & it : weights_map) {
const auto & weight = it.second;
struct ggml_tensor * tensor = weight.tensor;
if (weight.idx != cur_split && params->keep_split) {
close_ofstream();