llama : add phi3 128K model support (#7225)

* add phi3 128k support in convert-hf-to-gguf

* add phi3 128k support in cuda

* address build warnings on llama.cpp

* adjust index value in cuda long rope freq factors

* add long rope support in ggml cpu backend

* make freq factors only depend on ctx size

* remove unused rope scaling type 'su' frin gguf converter

* fix flint warnings on convert-hf-to-gguf.py

* set to the short freq factor when context size is small than trained context size

* add one line of comments

* metal : support rope freq_factors

* ggml : update ggml_rope_ext API to support freq. factors

* backends : add dev messages to support rope freq. factors

* minor : style

* tests : update to use new rope API

* backends : fix pragma semicolons

* minor : cleanup

* llama : move rope factors from KV header to tensors

* llama : remove tmp assert

* cuda : fix compile warning

* convert : read/write n_head_kv

* llama : fix uninitialized tensors

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
liuwei-git 2024-05-22 04:28:32 +08:00 committed by GitHub
parent 6369bf0433
commit 201cc11afa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 484 additions and 233 deletions

View File

@ -14,6 +14,7 @@ from pathlib import Path
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
import math
import numpy as np
import torch
@ -1784,23 +1785,59 @@ class Phi3MiniModel(Model):
def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
rot_pct = 1.0
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
rms_eps = self.find_hparam(["rms_norm_eps"])
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rope_dims = n_embd // n_head
self.gguf_writer.add_name("Phi3")
self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
self.gguf_writer.add_context_length(max_pos_embds)
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(8192)
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
self.gguf_writer.add_rope_dimension_count(rope_dims)
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
self.gguf_writer.add_file_type(self.ftype)
# write rope scaling for long context (128k) model
rope_scaling = self.find_hparam(['rope_scaling'], True)
if (rope_scaling is None):
return
scale = max_pos_embds / orig_max_pos_embds
rope_scaling_type = rope_scaling.get('type', '').lower()
if len(rope_scaling_type) == 0:
raise KeyError('Missing the required key rope_scaling.type')
if rope_scaling_type == 'su':
attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0
elif rope_scaling_type == 'yarn':
attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0
else:
raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet')
self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
@Model.register("PlamoForCausalLM")
class PlamoModel(Model):

View File

@ -563,8 +563,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;
return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
return ggml_rope_ext(ctx,
t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};

View File

@ -301,8 +301,8 @@ static struct ggml_tensor * llama_build_train_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;
return ggml_rope_custom(
ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
return ggml_rope_ext(
ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};

View File

@ -58,10 +58,10 @@ static __global__ void rope(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_pos>
template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@ -88,7 +88,9 @@ static __global__ void rope_neox(
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@ -164,7 +166,7 @@ static void rope_cuda(
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@ -175,15 +177,29 @@ static void rope_neox_cuda(
const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
if (freq_factors == nullptr) {
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
} else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
} else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
}
}
@ -214,24 +230,27 @@ static void rope_cuda_f32(
static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const float * freq_factors = nullptr;
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_d;
}
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
if (is_neox) {
pos = (const int32_t *) src1_d;
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
} else {
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
}
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false);

View File

@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
} break;
case GGML_OP_ROPE:
{
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
GGML_ASSERT(ne10 == ne02);
GGML_ASSERT(src0t == dstt);
// const int n_past = ((int32_t *) dst->op_params)[0];

View File

@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute(
const int64_t ne10 = src1 ? src1->ne[0] : 0;
const int64_t ne11 = src1 ? src1->ne[1] : 0;
const int64_t ne12 = src1 ? src1->ne[2] : 0;
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
const int64_t ne13 = src1 ? src1->ne[3] : 0;
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
const int64_t ne0 = dst ? dst->ne[0] : 0;
const int64_t ne1 = dst ? dst->ne[1] : 0;
const int64_t ne2 = dst ? dst->ne[2] : 0;
const int64_t ne3 = dst ? dst->ne[3] : 0;
const int64_t ne20 = src2 ? src2->ne[0] : 0;
const int64_t ne21 = src2 ? src2->ne[1] : 0;
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
const uint64_t nb0 = dst ? dst->nb[0] : 0;
const uint64_t nb1 = dst ? dst->nb[1] : 0;
const uint64_t nb2 = dst ? dst->nb[2] : 0;
const uint64_t nb3 = dst ? dst->nb[3] : 0;
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
const int64_t ne0 = dst ? dst->ne[0] : 0;
const int64_t ne1 = dst ? dst->ne[1] : 0;
const int64_t ne2 = dst ? dst->ne[2] : 0;
const int64_t ne3 = dst ? dst->ne[3] : 0;
const uint64_t nb0 = dst ? dst->nb[0] : 0;
const uint64_t nb1 = dst ? dst->nb[1] : 0;
const uint64_t nb2 = dst ? dst->nb[2] : 0;
const uint64_t nb3 = dst ? dst->nb[3] : 0;
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute(
const int n_as = src0->ne[2];
// src2 = ids
const int64_t ne20 = src2->ne[0];
const int64_t ne21 = src2->ne[1];
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
const uint64_t nb21 = src2->nb[1];
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
GGML_ASSERT(src2t == GGML_TYPE_I32);
@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute(
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@ -2252,6 +2258,15 @@ static enum ggml_status ggml_metal_graph_compute(
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
if (!is_neox) {
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
}
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
@ -2263,33 +2278,38 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
if (id_src2 != nil) {
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@ -2535,11 +2555,6 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);

View File

@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
device const float * src2,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
@ -1675,6 +1676,7 @@ template<typename T>
kernel void kernel_rope(
device const void * src0,
device const int32_t * src1,
device const float * src2,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
@ -1744,8 +1746,10 @@ kernel void kernel_rope(
// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
const float theta = theta_0 * pow(freq_base, cur_rot);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

View File

@ -14454,6 +14454,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const dpct::queue_ptr &main_stream) {
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);

View File

@ -4238,6 +4238,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
}
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];

88
ggml.c
View File

@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl(
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
if (c) {
GGML_ASSERT(c->type == GGML_TYPE_F32);
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}
bool is_node = false;
if (a->grad) {
@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl(
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return result;
}
@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope(
int mode,
int n_ctx) {
return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
);
}
@ -6295,7 +6302,49 @@ struct ggml_tensor * ggml_rope_inplace(
int mode,
int n_ctx) {
return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
);
}
struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
);
}
struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
);
}
@ -6314,7 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
);
}
@ -6334,27 +6383,18 @@ struct ggml_tensor * ggml_rope_custom_inplace(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
);
}
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down) {
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
}
// ggml_rope_back
struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back(
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
GGML_ASSERT(c == NULL && "freq factors not implemented yet");
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src2 = dst->src[2];
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32(
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const float * freq_factors = NULL;
if (is_neox) {
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
freq_factors = (const float *) src2->data;
}
} else {
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
}
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32(
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
struct ggml_tensor * src2 = tensor->src[2];
switch (tensor->op) {
case GGML_OP_DUP:
@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_rope_back(ctx,
tensor->grad,
src1,
src2,
n_dims,
mode,
n_ctx,
@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_rope_impl(ctx,
tensor->grad,
src1,
src2,
n_dims,
mode,
n_ctx,
@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
masked);
}
struct ggml_tensor * src2 = tensor->src[2];
const int64_t elem_q = ggml_nelements(src0);
const int64_t elem_k = ggml_nelements(src1);
const int64_t elem_v = ggml_nelements(src2);

49
ggml.h
View File

@ -1465,6 +1465,7 @@ extern "C" {
// if mode & 4 == 1, ChatGLM style
//
// b is an int32 vector with size a->ne[2], it contains the positions
// c is freq factors (e.g. phi3-128k), (optional)
GGML_API struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1483,10 +1484,11 @@ extern "C" {
int n_ctx);
// custom RoPE
GGML_API struct ggml_tensor * ggml_rope_custom(
GGML_API struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@ -1499,7 +1501,23 @@ extern "C" {
float beta_slow);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
@ -1512,20 +1530,28 @@ extern "C" {
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
float beta_slow),
"use ggml_rope_ext instead");
// compute correction dims for YaRN RoPE scaling
GGML_CALL void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
// xPos RoPE, in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down);
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow),
"use ggml_rope_ext_inplace instead");
// compute correction dims for YaRN RoPE scaling
GGML_CALL void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
// rotary position embedding backward, i.e compute dx from dy
// a - dy
@ -1533,6 +1559,7 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,

View File

@ -57,12 +57,13 @@ class Keys:
CAUSAL = "{arch}.attention.causal"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
FREQ_BASE = "{arch}.rope.freq_base"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
DIMENSION_COUNT = "{arch}.rope.dimension_count"
FREQ_BASE = "{arch}.rope.freq_base"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
class SSM:
CONV_KERNEL = "{arch}.ssm.conv_kernel"
@ -148,6 +149,8 @@ class MODEL_TENSOR(IntEnum):
OUTPUT = auto()
OUTPUT_NORM = auto()
ROPE_FREQS = auto()
ROPE_FACTORS_LONG = auto()
ROPE_FACTORS_SHORT = auto()
ATTN_Q = auto()
ATTN_K = auto()
ATTN_V = auto()
@ -225,6 +228,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",

View File

@ -433,6 +433,9 @@ class GGUFWriter:
def add_rope_scaling_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None:
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)

277
llama.cpp
View File

@ -304,6 +304,7 @@ enum llm_kv {
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
@ -381,6 +382,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
@ -436,6 +438,8 @@ enum llm_tensor {
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ROPE_FREQS,
LLM_TENSOR_ROPE_FACTORS_LONG,
LLM_TENSOR_ROPE_FACTORS_SHORT,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
@ -803,18 +807,20 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{
LLM_ARCH_PHI3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
@ -1750,6 +1756,7 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
@ -1798,6 +1805,7 @@ struct llama_hparams {
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true;
if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
@ -2103,6 +2111,10 @@ struct llama_model {
struct ggml_tensor * output;
struct ggml_tensor * output_b;
// long rope factors
struct ggml_tensor * rope_long = nullptr;
struct ggml_tensor * rope_short = nullptr;
std::vector<llama_layer> layers;
llama_split_mode split_mode;
@ -3306,6 +3318,39 @@ struct llama_model_loader {
return get_arr_n(llm_kv(kid), result, required);
}
template<typename T>
bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
const int kid = gguf_find_key(meta, key.c_str());
if (kid < 0) {
if (required) {
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
}
return false;
}
struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) {
throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str()));
}
// GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value));
GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same<T, int>::value));
result.resize(arr_info.length);
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
return true;
}
template<typename T>
bool get_arr(const enum llm_kv kid, T& result, const bool required = true) {
return get_arr(llm_kv(kid), result, required);
}
template<typename T>
bool get_key(const std::string & key, T & result, const bool required = true) {
auto it = kv_overrides.find(key);
@ -3849,6 +3894,8 @@ static void llm_load_hparams(
}
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
// sanity check for n_rot (optional)
{
hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
@ -4880,6 +4927,7 @@ static bool llm_load_tensors(
// create tensors for the weights
{
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd_head = n_embd / hparams.n_head;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_embd_gqa = n_embd_v_gqa;
@ -5591,6 +5639,9 @@ static bool llm_load_tensors(
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
model.rope_long = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, false);
model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false);
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
@ -5601,12 +5652,12 @@ static bool llm_load_tensors(
ggml_context* ctx_layer = ctx_for_layer(i);
ggml_context* ctx_split = ctx_for_layer_split(i);
auto& layer = model.layers[i];
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd });
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false);
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd });
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd });
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd });
@ -6821,17 +6872,20 @@ struct llm_build_context {
cb(lctx.inp_K_shift, "K_shift", -1);
ggml_set_input(lctx.inp_K_shift);
struct ggml_tensor * rope_factors = build_rope_factors();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx0,
ggml_rope_ext_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_head_kv, n_ctx,
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
0),
lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
ggml_build_forward_expand(gf, tmp);
}
@ -6934,6 +6988,17 @@ struct llm_build_context {
return lctx.inp_pos;
}
struct ggml_tensor * build_rope_factors() {
// choose long/short freq factors based on the context size
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
return model.rope_long;
}
return model.rope_short;
}
struct ggml_tensor * build_inp_out_ids() {
lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
cb(lctx.inp_out_ids, "inp_out_ids", -1);
@ -7041,15 +7106,15 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -7171,13 +7236,13 @@ struct llm_build_context {
switch (model.type) {
case MODEL_7B:
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -7283,15 +7348,15 @@ struct llm_build_context {
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -7404,14 +7469,14 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
// using mode = 2 for neox mode
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -7527,15 +7592,15 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -7679,15 +7744,15 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -8032,15 +8097,15 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -8472,15 +8537,15 @@ struct llm_build_context {
}
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos,
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos,
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -8592,14 +8657,14 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
// using mode = 2 for neox mode
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -8703,15 +8768,15 @@ struct llm_build_context {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -8817,15 +8882,15 @@ struct llm_build_context {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -8969,8 +9034,8 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
@ -8980,8 +9045,8 @@ struct llm_build_context {
Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9052,6 +9117,9 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
// rope freq factors for 128k context
struct ggml_tensor * rope_factors = build_rope_factors();
for (int il = 0; il < n_layer; ++il) {
auto residual = inpL;
@ -9088,8 +9156,8 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
@ -9097,8 +9165,8 @@ struct llm_build_context {
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9204,14 +9272,14 @@ struct llm_build_context {
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
@ -9412,15 +9480,15 @@ struct llm_build_context {
cb(tmpk, "tmpk", il);
cb(Vcur, "Vcur", il);
struct ggml_tensor * Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
struct ggml_tensor * Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
struct ggml_tensor * Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -9528,15 +9596,15 @@ struct llm_build_context {
// cb(Vcur, "Vcur", il);
// }
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -9645,15 +9713,15 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -9775,15 +9843,15 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -9895,8 +9963,8 @@ struct llm_build_context {
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
@ -9904,8 +9972,8 @@ struct llm_build_context {
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
cb(Qcur, "Qcur_scaled", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
@ -10015,15 +10083,15 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -10305,15 +10373,15 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -10436,15 +10504,15 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -15417,6 +15485,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
cparams.causal_attn = hparams.causal_attn;
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {

View File

@ -1763,14 +1763,14 @@ struct test_llama : public test_llm {
struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
Qcur = ggml_rope_custom(
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
Qcur = ggml_rope_ext(
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_custom(
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
Kcur = ggml_rope_ext(
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -1889,13 +1889,13 @@ struct test_falcon : public test_llm {
Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
// using mode = 2 for neox mode
Qcur = ggml_rope_custom(
ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
Qcur = ggml_rope_ext(
ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_custom(
ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
Kcur = ggml_rope_ext(
ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);