ggml : change ggml_scale to take a float instead of tensor

This commit is contained in:
Georgi Gerganov 2023-12-21 20:50:24 +02:00
parent 8fe03ffdda
commit 199f6bdc46
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
10 changed files with 68 additions and 186 deletions

View File

@ -575,10 +575,7 @@ static struct ggml_tensor * forward(
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]
@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch(
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
// KQ_masked = mask_past(KQ_scaled)
@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora(
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]

View File

@ -309,7 +309,7 @@ static struct ggml_cgraph * build_graph_lora(
) {
struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
if (scaling != 1.0f) {
ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
ab = ggml_scale(ctx, ab, scaling);
}
struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);

View File

@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
float rope_freq_scale = 1.0f;
GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (rope_freq_scale != 1.0f) {
hparams->rope_freq_scale = 1.0f / rope_freq_scale;
}
@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int n_rot = hparams.n_embd_head();
const int n_embd_head = hparams.n_embd_head();
const int n_embd_gqa = hparams.n_embd_gqa();
const float rms_norm_eps = hparams.f_norm_rms_eps;
const float rope_freq_base = hparams.rope_freq_base;
const float rope_freq_scale = hparams.rope_freq_scale;
@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
checkpoints.push_back(t01);
}
struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes;
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
// output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
ggml_allocr_alloc(alloc, t36->grad);
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
// make sure base model tensors data cannot be used in viewable operations
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
}
// allocating checkpoints in one block to reduce memory fragmentation

View File

@ -330,12 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
ggml_repeat(ctx0, model.pre_ln_b, embeddings));
}
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(ctx->alloc, KQ_scale);
if (!ggml_allocr_is_measure(ctx->alloc)) {
ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
}
// loop over layers
for (int il = 0; il < n_layer - 1; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
@ -356,7 +350,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
Q = ggml_scale_inplace(ctx0, Q, KQ_scale);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);

View File

@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs(
checkpoints.push_back(t00);
checkpoints.push_back(t01);
struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs(
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes;
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
// output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
ggml_allocr_alloc(alloc, t36->grad);

View File

@ -7694,17 +7694,9 @@ inline void ggml_cuda_op_scale(
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
float scale;
// HACK: support for ggml backend interface
if (src1->backend == GGML_BACKEND_CPU) {
scale = ((float *) src1->data)[0];
} else {
// TODO: pass pointer to kernel instead of copying to host
CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
}
const float scale = ((float *) dst->op_params)[0];
scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());
@ -7751,8 +7743,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;
// dd = data device
float * src0_ddf = nullptr;
float * src1_ddf = nullptr;
@ -7773,7 +7763,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
}
if (use_src1 && !src1_stays_on_host) {
if (use_src1) {
if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else {

View File

@ -1261,7 +1261,7 @@ void ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));
const float scale = *(const float *) src1->data;
const float scale = *(const float *) dst->op_params;
int64_t n = ggml_nelements(dst);
@ -1272,8 +1272,8 @@ void ggml_metal_graph_compute(
[encoder setComputePipelineState:ctx->pipeline_scale];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];

33
ggml.c
View File

@ -4183,23 +4183,23 @@ struct ggml_tensor * ggml_out_prod(
static struct ggml_tensor * ggml_scale_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float s,
bool inplace) {
GGML_ASSERT(ggml_is_scalar(b));
GGML_ASSERT(ggml_is_padded_1d(a));
bool is_node = false;
if (a->grad || b->grad) {
if (a->grad) {
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params(result, &s, sizeof(s));
result->op = GGML_OP_SCALE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
@ -4207,15 +4207,15 @@ static struct ggml_tensor * ggml_scale_impl(
struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_scale_impl(ctx, a, b, false);
float s) {
return ggml_scale_impl(ctx, a, s, false);
}
struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_scale_impl(ctx, a, b, true);
float s) {
return ggml_scale_impl(ctx, a, s, true);
}
// ggml_set
@ -14851,7 +14851,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg
static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
if (ggml_hash_contains(zero_table, a)) {
struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
} else {
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
@ -14987,7 +14987,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad,
ggml_scale(ctx,
ggml_mul(ctx, src0, tensor->grad),
ggml_new_f32(ctx, 2.0f)),
2.0f),
zero_table);
}
} break;
@ -15001,7 +15001,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_div(ctx,
tensor->grad,
tensor),
ggml_new_f32(ctx, 0.5f)),
0.5f),
zero_table);
}
} break;
@ -15167,17 +15167,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
const float s = ((float *) tensor->op_params)[0];
src0->grad =
ggml_add_or_set(ctx,
src0->grad,
ggml_scale_impl(ctx, tensor->grad, src1, false),
zero_table);
}
if (src1->grad) {
src1->grad =
ggml_add_or_set(ctx,
src1->grad,
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
ggml_scale_impl(ctx, tensor->grad, s, false),
zero_table);
}
} break;

4
ggml.h
View File

@ -1094,13 +1094,13 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);
// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(

116
llama.cpp
View File

@ -4088,13 +4088,12 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * wo,
struct ggml_tensor * wo_b,
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_scale,
struct ggml_tensor * kq_mask,
int64_t n_ctx,
int32_t n_tokens,
int32_t n_kv,
float max_alibi_bias,
float scale,
float kq_scale,
const llm_build_cb & cb,
int il) {
const int64_t n_embd = hparams.n_embd;
@ -4142,7 +4141,7 @@ static struct ggml_tensor * llm_build_kqv(
kq = ggml_soft_max(ctx, kq);
cb(kq, "kq_soft_max", il);
} else {
kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale);
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
cb(kq, "kq_soft_max_ext", il);
}
@ -4287,10 +4286,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -4351,7 +4346,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -4472,10 +4467,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -4534,7 +4525,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -4592,10 +4583,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -4658,7 +4645,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -4715,10 +4702,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -4758,7 +4741,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -4815,10 +4798,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -4967,7 +4946,7 @@ struct llm_build_context {
// TODO: not tested, could be broken
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Q, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -5021,10 +5000,6 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "inp_embd", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5058,7 +5033,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -5112,10 +5087,6 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "inp_embd", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5155,7 +5126,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -5206,10 +5177,6 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "inp_embd", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5249,7 +5216,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -5309,10 +5276,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5362,7 +5325,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -5422,10 +5385,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5479,7 +5438,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
@ -5538,14 +5497,6 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// Q_scale
struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(Q_scale, "Q_scale", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5587,7 +5538,9 @@ struct llm_build_context {
);
cb(Qcur, "Qcur", il);
Qcur = ggml_scale(ctx0, Qcur, Q_scale);
// with phi2, we scale the Q to avoid precision issues
// ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
@ -5600,7 +5553,7 @@ struct llm_build_context {
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
cb(cur, "kqv_out", il);
}
@ -5737,8 +5690,6 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "pos_embd", OFFLOAD_FUNC_NR },
{ "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
{ "Q_scale", OFFLOAD_FUNC_FRC },
{ "KQ_scale", OFFLOAD_FUNC_FRC },
{ "KQ_mask", OFFLOAD_FUNC_FRC },
{ "K_shift", OFFLOAD_FUNC_FRC },
@ -5840,8 +5791,6 @@ static struct ggml_cgraph * llama_build_graph(
bool alloc_inp_tokens = false;
bool alloc_inp_embd = false;
bool alloc_inp_pos = false;
bool alloc_inp_Q_scale = false;
bool alloc_inp_KQ_scale = false;
bool alloc_inp_KQ_mask = false;
bool alloc_inp_K_shift = false;
@ -5908,34 +5857,6 @@ static struct ggml_cgraph * llama_build_graph(
alloc_inp_pos = true;
}
if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
const int64_t n_embd_head = model.hparams.n_embd_head();
ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
}
alloc_inp_Q_scale = true;
}
if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
const int64_t n_embd_head = model.hparams.n_embd_head();
if (model.arch == LLM_ARCH_PHI2) {
// with phi2, we scale the Q to avoid precision issues
// ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
ggml_set_f32(cur, 1.0f);
} else {
ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
}
}
alloc_inp_KQ_scale = true;
}
if (!alloc_inp_KQ_mask && strcmp(name, "KQ_mask") == 0) {
ggml_allocr_alloc(lctx.alloc, cur);
@ -9113,10 +9034,7 @@ static int llama_apply_lora_from_file_internal(
ggml_set_name(BA, "BA");
if (scaling != 1.0f) {
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx.get(), scaling);
ggml_set_name(scale_tensor, "scale_tensor");
BA = ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor);
BA = ggml_scale_inplace(lora_ctx.get(), BA, scaling);
offload_func(BA);
ggml_set_name(BA, "BA_scaled");
}