Merge branch 'ggerganov:master' into master

This commit is contained in:
Steward Garcia 2023-10-15 18:23:47 -04:00 committed by GitHub
commit f47fd17b73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 75 additions and 30 deletions

View File

@ -98,6 +98,8 @@ gguf_writer.add_embedding_length(hparams["d_model"])
gguf_writer.add_block_count(block_count) gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(4 * hparams["d_model"]) gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
gguf_writer.add_head_count(hparams["n_heads"]) gguf_writer.add_head_count(hparams["n_heads"])
if kv_n_heads := hparams["attn_config"].get("kv_n_heads"):
gguf_writer.add_head_count_kv(kv_n_heads)
gguf_writer.add_layer_norm_eps(1e-05) gguf_writer.add_layer_norm_eps(1e-05)
if hparams["attn_config"]["clip_qkv"] is not None: if hparams["attn_config"]["clip_qkv"] is not None:
gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])

View File

@ -529,13 +529,14 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
set_param_lora(lora); set_param_lora(lora);
// measure data size // measure data size
struct ggml_allocr * alloc = NULL; size_t size = 0;
alloc = ggml_allocr_new_measure(tensor_alignment); for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
alloc_lora(alloc, lora); size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
}
// allocate data // allocate data
lora->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment); struct ggml_allocr * alloc = NULL;
ggml_allocr_free(alloc); lora->data.resize(size + tensor_alignment);
alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment); alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment);
alloc_lora(alloc, lora); alloc_lora(alloc, lora);
ggml_allocr_free(alloc); ggml_allocr_free(alloc);
@ -1714,11 +1715,9 @@ int main(int argc, char ** argv) {
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
// measure required memory for input tensors // measure required memory for input tensors
alloc = ggml_allocr_new_measure(tensor_alignment); size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
ggml_allocr_alloc(alloc, tokens_input); GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
ggml_allocr_alloc(alloc, target_probs); tensor_alignment;
size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment;
ggml_allocr_free(alloc);
printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f)); printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
// allocate input tensors // allocate input tensors

View File

@ -79,7 +79,13 @@ int main(int argc, char ** argv) {
llama_backend_init(params.numa); llama_backend_init(params.numa);
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = params.n_gpu_layers;
model_params.main_gpu = params.main_gpu;
model_params.tensor_split = params.tensor_split;
model_params.use_mmap = params.use_mmap;
model_params.use_mlock = params.use_mlock;
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) { if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__); fprintf(stderr , "%s: error: unable to load model\n" , __func__);

View File

@ -19,7 +19,7 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
#define CL_DMMV_BLOCK_SIZE 32 #define CL_DMMV_LOCAL_SIZE 32
#ifndef K_QUANTS_PER_ITERATION #ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1 #define K_QUANTS_PER_ITERATION 1
@ -338,7 +338,7 @@ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx,
const int row = get_group_id(0); const int row = get_group_id(0);
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row + get_global_offset(0);
__global const struct block_q2_K * x = xx + ib0; __global const struct block_q2_K * x = xx + ib0;
@ -413,7 +413,7 @@ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx,
const int row = get_group_id(0); const int row = get_group_id(0);
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row + get_global_offset(0);
__global const struct block_q3_K * x = xx + ib0; __global const struct block_q3_K * x = xx + ib0;
@ -489,7 +489,7 @@ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx,
const int row = get_group_id(0); const int row = get_group_id(0);
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row + get_global_offset(0);
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15 const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
@ -562,7 +562,7 @@ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx,
const int row = get_group_id(0); const int row = get_group_id(0);
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row + get_global_offset(0);
const int tid = get_local_id(0)/2; // 0...15 const int tid = get_local_id(0)/2; // 0...15
const int ix = get_local_id(0)%2; const int ix = get_local_id(0)%2;
@ -641,7 +641,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
const int row = get_group_id(0); const int row = get_group_id(0);
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row + get_global_offset(0);
__global const struct block_q6_K * x = xx + ib0; __global const struct block_q6_K * x = xx + ib0;
@ -745,19 +745,21 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE( std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
const int block_size = get_local_size(0); const int local_size = get_local_size(0);
const int row = get_group_id(0); const int row = get_group_id(0);
const int tid = get_local_id(0); const int tid = get_local_id(0);
const uint qk = QUANT_K; const uint qk = QUANT_K;
const uint qr = QUANT_R; const uint qr = QUANT_R;
const int col_step = local_size * 2;
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
x += get_global_offset(0);
tmp[tid] = 0; tmp[tid] = 0;
for (int i = 0; i < ncols/block_size; i += 2) { for (int col = tid*2; col < ncols; col += col_step) {
const int col = i*block_size + 2*tid;
const int ib = (row*ncols + col)/qk; // block index const int ib = (row*ncols + col)/qk; // block index
const int iqs = (col%qk)/qr; // quant index const int iqs = (col%qk)/qr; // quant index
const int iybs = col - col%qk; // y block start index const int iybs = col - col%qk; // y block start index
@ -773,7 +775,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
// sum up partial sums and write back result // sum up partial sums and write back result
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
for (int s=block_size/2; s>0; s>>=1) { for (int s=local_size/2; s>0; s>>=1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
@ -1704,7 +1706,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
const int nb2 = dst->nb[2]; const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3]; const int nb3 = dst->nb[3];
const ggml_type type = src0->type; const ggml_type type = src0->type;
const bool mul_mat_vec = ne11 == 1; const bool mul_mat_vec = ne11 == 1 && ne00%2 == 0;
const int64_t r2 = ne12 / ne02; const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03; const int64_t r3 = ne13 / ne03;
@ -1737,7 +1739,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
GGML_ASSERT(to_fp32_cl != nullptr); GGML_ASSERT(to_fp32_cl != nullptr);
const size_t global_denom = ggml_cl_global_denom(type); const size_t global_denom = ggml_cl_global_denom(type);
const size_t local = ggml_cl_local_size(type); const size_t local = mul_mat_vec ? CL_DMMV_LOCAL_SIZE : ggml_cl_local_size(type);
size_t ev_idx = 0; size_t ev_idx = 0;
std::vector<cl_event> events; std::vector<cl_event> events;
@ -1770,8 +1772,8 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++)); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));
// compute // compute
const size_t global = ne01 * CL_DMMV_BLOCK_SIZE; const size_t global = ne01 * local;
const size_t local = CL_DMMV_BLOCK_SIZE; const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
const cl_int ncols = ne00; const cl_int ncols = ne00;
events.emplace_back(); events.emplace_back();
CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q)); CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
@ -1779,7 +1781,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y)); CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y));
CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D)); CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D));
CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols)); CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols));
CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++)); CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, &offset, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
} else { // general dequantization kernel + CLBlast matrix matrix multiplication } else { // general dequantization kernel + CLBlast matrix matrix multiplication
// convert src0 to fp32 on device // convert src0 to fp32 on device
const size_t global = x_ne / global_denom; const size_t global = x_ne / global_denom;

34
ggml.c
View File

@ -5494,6 +5494,39 @@ struct ggml_tensor * ggml_view_tensor(
return result; return result;
} }
struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) {
struct ggml_object * obj = ctx->objects_begin;
char * const mem_buffer = ctx->mem_buffer;
while (obj != NULL) {
if (obj->type == GGML_OBJECT_TENSOR) {
return (struct ggml_tensor *)(mem_buffer + obj->offs);
}
obj = obj->next;
}
return NULL;
}
struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) {
struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
obj = obj->next;
char * const mem_buffer = ctx->mem_buffer;
while (obj != NULL) {
if (obj->type == GGML_OBJECT_TENSOR) {
return (struct ggml_tensor *)(mem_buffer + obj->offs);
}
obj = obj->next;
}
return NULL;
}
struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
struct ggml_object * obj = ctx->objects_begin; struct ggml_object * obj = ctx->objects_begin;
@ -8647,6 +8680,7 @@ void ggml_set_param(
GGML_ASSERT(tensor->grad == NULL); GGML_ASSERT(tensor->grad == NULL);
tensor->grad = ggml_dup_tensor(ctx, tensor); tensor->grad = ggml_dup_tensor(ctx, tensor);
ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
} }
// ggml_compute_forward_dup // ggml_compute_forward_dup

3
ggml.h
View File

@ -704,6 +704,9 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
// Context tensor enumeration and lookup
GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx);
GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);

View File

@ -2839,8 +2839,8 @@ static void llm_load_tensors(
auto & layer = model.layers[i]; auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split); layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
@ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt(
const int64_t n_layer = hparams.n_layer; const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = cparams.n_ctx; const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head; const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_embd_gqa = hparams.n_embd_gqa(); const int64_t n_embd_gqa = hparams.n_embd_gqa();
@ -5721,7 +5721,6 @@ static struct ggml_cgraph * llama_build_graph(
// //
// - lctx: llama context // - lctx: llama context
// - batch: batch to evaluate // - batch: batch to evaluate
// - n_threads: number of threads to use
// //
// return 0 on success // return 0 on success
// return positive int on warning // return positive int on warning