diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6357d4034..daad1c4fc 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -14,6 +14,7 @@ from pathlib import Path from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast +import math import numpy as np import torch @@ -1784,23 +1785,59 @@ class Phi3MiniModel(Model): def set_gguf_parameters(self): block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - rot_pct = 1.0 n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) + n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) rms_eps = self.find_hparam(["rms_norm_eps"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rope_dims = n_embd // n_head self.gguf_writer.add_name("Phi3") - self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"])) - + self.gguf_writer.add_context_length(max_pos_embds) + self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_embedding_length(n_embd) - self.gguf_writer.add_feed_forward_length(8192) + self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(n_head) - self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(rms_eps) - self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + self.gguf_writer.add_rope_dimension_count(rope_dims) + self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) self.gguf_writer.add_file_type(self.ftype) + # write rope scaling for long context (128k) model + rope_scaling = self.find_hparam(['rope_scaling'], True) + if (rope_scaling is None): + return + + scale = max_pos_embds / orig_max_pos_embds + + rope_scaling_type = rope_scaling.get('type', '').lower() + if len(rope_scaling_type) == 0: + raise KeyError('Missing the required key rope_scaling.type') + + if rope_scaling_type == 'su': + attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0 + elif rope_scaling_type == 'yarn': + attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0 + else: + raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet') + + self.gguf_writer.add_rope_scaling_attn_factors(attn_factor) + + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + @Model.register("PlamoForCausalLM") class PlamoModel(Model): diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 22743b1bf..992426c1b 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -563,8 +563,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( // not capturing these, to silcence warnings const int rope_mode = 0; - return ggml_rope_custom(ctx, - t, KQ_pos, n_rot, rope_mode, n_ctx, 0, + return ggml_rope_ext(ctx, + t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 587418cc7..45bdfa8f5 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -301,8 +301,8 @@ static struct ggml_tensor * llama_build_train_graphs( // not capturing these, to silcence warnings const int rope_mode = 0; - return ggml_rope_custom( - ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f + return ggml_rope_ext( + ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 4b0d2e5ad..4a558f4b3 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -58,10 +58,10 @@ static __global__ void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -88,7 +88,9 @@ static __global__ void rope_neox( float cur_rot = inv_ndims * ic - ib; const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); + const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; + + const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor; float cos_theta, sin_theta; rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); @@ -164,7 +166,7 @@ static void rope_cuda( template static void rope_neox_cuda( const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -175,15 +177,29 @@ static void rope_neox_cuda( const float inv_ndims = -1.0f / n_dims; if (pos == nullptr) { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } else { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } } else { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } else { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } } } @@ -214,24 +230,27 @@ static void rope_cuda_f32( static void rope_neox_cuda_f16( const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f32( const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); @@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t ne2 = dst->ne[2]; const int64_t nrows = ggml_nrows(src0); //const int n_past = ((int32_t *) dst->op_params)[0]; @@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const float * freq_factors = nullptr; const int32_t * pos = nullptr; - if ((mode & 1) == 0) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); - pos = (const int32_t *) src1_d; - } const bool is_neox = mode & 2; const bool is_glm = mode & 4; + if (is_neox) { + pos = (const int32_t *) src1_d; + + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); + } + rope_corr_dims corr_dims; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); @@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (src0->type == GGML_TYPE_F32) { rope_neox_cuda_f32( (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda_f16( (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + attn_factor, corr_dims, freq_factors, stream ); } else { GGML_ASSERT(false); diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 3f033d58b..6c6058b2a 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml } break; case GGML_OP_ROPE: { +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + GGML_ASSERT(ne10 == ne02); GGML_ASSERT(src0t == dstt); // const int n_past = ((int32_t *) dst->op_params)[0]; diff --git a/ggml-metal.m b/ggml-metal.m index b0b16dbf7..5d5ad20ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t ne10 = src1 ? src1->ne[0] : 0; const int64_t ne11 = src1 ? src1->ne[1] : 0; const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + const int64_t ne13 = src1 ? src1->ne[3] : 0; const uint64_t nb10 = src1 ? src1->nb[0] : 0; const uint64_t nb11 = src1 ? src1->nb[1] : 0; const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; @@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute( const int n_as = src0->ne[2]; // src2 = ids - const int64_t ne20 = src2->ne[0]; - const int64_t ne21 = src2->ne[1]; - const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); - const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20); - const uint64_t nb21 = src2->nb[1]; - const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22); - const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23); - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); GGML_ASSERT(src2t == GGML_TYPE_I32); @@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute( // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2252,6 +2258,15 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); + + if (!is_neox) { + GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox"); + } + id pipeline = nil; switch (src0->type) { @@ -2263,33 +2278,38 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; - [encoder setBytes:&mode length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&mode length:sizeof( int) atIndex:22]; + [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2535,11 +2555,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index cf262e834..c5eb25280 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims( typedef void (rope_t)( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1675,6 +1676,7 @@ template kernel void kernel_rope( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1744,8 +1746,10 @@ kernel void kernel_rope( // simplified from `(ib * n_dims + ic) * inv_ndims` const float cur_rot = inv_ndims*ic - ib; + const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f; + + const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor; - const float theta = theta_0 * pow(freq_base, cur_rot); float cos_theta, sin_theta; rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index eac8f5579..f486b6c0a 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -14454,6 +14454,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const dpct::queue_ptr &main_stream) { +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index aff451b63..16287a280 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -4238,6 +4238,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; diff --git a/ggml.c b/ggml.c index 4bd911528..37b16b7a9 100644 --- a/ggml.c +++ b/ggml.c @@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl( GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + if (c) { + GGML_ASSERT(c->type == GGML_TYPE_F32); + GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + bool is_node = false; if (a->grad) { @@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; + result->src[2] = c; return result; } @@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope( int mode, int n_ctx) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false ); } @@ -6295,7 +6302,49 @@ struct ggml_tensor * ggml_rope_inplace( int mode, int n_ctx) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ); +} + +struct ggml_tensor * ggml_rope_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ); +} + +struct ggml_tensor * ggml_rope_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true ); } @@ -6314,7 +6363,7 @@ struct ggml_tensor * ggml_rope_custom( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false ); } @@ -6334,27 +6383,18 @@ struct ggml_tensor * ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true ); } -struct ggml_tensor * ggml_rope_xpos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - float base, - bool down) { - return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); -} - // ggml_rope_back struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back( GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT(c == NULL && "freq factors not implemented yet"); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const float * freq_factors = NULL; + if (is_neox) { + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1"); + } + // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. // this essentially just switches the sign of sin. @@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32( // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; + float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta ); sin_theta *= sin_sign; @@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; + struct ggml_tensor * src2 = tensor->src[2]; switch (tensor->op) { case GGML_OP_DUP: @@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_rope_back(ctx, tensor->grad, src1, + src2, n_dims, mode, n_ctx, @@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_rope_impl(ctx, tensor->grad, src1, + src2, n_dims, mode, n_ctx, @@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor masked); } - struct ggml_tensor * src2 = tensor->src[2]; const int64_t elem_q = ggml_nelements(src0); const int64_t elem_k = ggml_nelements(src1); const int64_t elem_v = ggml_nelements(src2); diff --git a/ggml.h b/ggml.h index 774757101..35ac9110c 100644 --- a/ggml.h +++ b/ggml.h @@ -1465,6 +1465,7 @@ extern "C" { // if mode & 4 == 1, ChatGLM style // // b is an int32 vector with size a->ne[2], it contains the positions + // c is freq factors (e.g. phi3-128k), (optional) GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1483,10 +1484,11 @@ extern "C" { int n_ctx); // custom RoPE - GGML_API struct ggml_tensor * ggml_rope_custom( + GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -1499,7 +1501,23 @@ extern "C" { float beta_slow); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + GGML_API struct ggml_tensor * ggml_rope_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -1512,20 +1530,28 @@ extern "C" { float ext_factor, float attn_factor, float beta_fast, - float beta_slow); + float beta_slow), + "use ggml_rope_ext instead"); - // compute correction dims for YaRN RoPE scaling - GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); - - // xPos RoPE, in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - float base, - bool down); + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext_inplace instead"); + + // compute correction dims for YaRN RoPE scaling + GGML_CALL void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1533,6 +1559,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 692120f4d..42df2e4d0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -57,12 +57,13 @@ class Keys: CAUSAL = "{arch}.attention.causal" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - FREQ_BASE = "{arch}.rope.freq_base" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + FREQ_BASE = "{arch}.rope.freq_base" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" @@ -148,6 +149,8 @@ class MODEL_TENSOR(IntEnum): OUTPUT = auto() OUTPUT_NORM = auto() ROPE_FREQS = auto() + ROPE_FACTORS_LONG = auto() + ROPE_FACTORS_SHORT = auto() ATTN_Q = auto() ATTN_K = auto() ATTN_V = auto() @@ -225,6 +228,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", + MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d5e323a52..8b41b54ea 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -433,6 +433,9 @@ class GGUFWriter: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None: + self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_orig_ctx_len(self, value: int) -> None: self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value) diff --git a/llama.cpp b/llama.cpp index d26fe559a..abff8c1c0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -304,6 +304,7 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -381,6 +382,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -436,6 +438,8 @@ enum llm_tensor { LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, LLM_TENSOR_ATTN_Q, LLM_TENSOR_ATTN_K, LLM_TENSOR_ATTN_V, @@ -803,18 +807,20 @@ static const std::map> LLM_TENSOR_NA { LLM_ARCH_PHI3, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -1750,6 +1756,7 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; @@ -1798,6 +1805,7 @@ struct llama_hparams { if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; @@ -2103,6 +2111,10 @@ struct llama_model { struct ggml_tensor * output; struct ggml_tensor * output_b; + // long rope factors + struct ggml_tensor * rope_long = nullptr; + struct ggml_tensor * rope_short = nullptr; + std::vector layers; llama_split_mode split_mode; @@ -3306,6 +3318,39 @@ struct llama_model_loader { return get_arr_n(llm_kv(kid), result, required); } + template + bool get_arr(const std::string & key, std::vector & result, const bool required = true) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { + throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + } + + // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); + + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + + return true; + } + + template + bool get_arr(const enum llm_kv kid, T& result, const bool required = true) { + return get_arr(llm_kv(kid), result, required); + } + template bool get_key(const std::string & key, T & result, const bool required = true) { auto it = kv_overrides.find(key); @@ -3849,6 +3894,8 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + // sanity check for n_rot (optional) { hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; @@ -4880,6 +4927,7 @@ static bool llm_load_tensors( // create tensors for the weights { const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_head = n_embd / hparams.n_head; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = n_embd_v_gqa; @@ -5591,6 +5639,9 @@ static bool llm_load_tensors( { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + model.rope_long = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, false); + model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false); + // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); @@ -5601,12 +5652,12 @@ static bool llm_load_tensors( ggml_context* ctx_layer = ctx_for_layer(i); ggml_context* ctx_split = ctx_for_layer_split(i); - auto& layer = model.layers[i]; + auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); @@ -6821,17 +6872,20 @@ struct llm_build_context { cb(lctx.inp_K_shift, "K_shift", -1); ggml_set_input(lctx.inp_K_shift); + struct ggml_tensor * rope_factors = build_rope_factors(); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions - ggml_rope_custom_inplace(ctx0, + ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_k, n_head_kv, n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0), - lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + cb(tmp, "K_shifted", il); ggml_build_forward_expand(gf, tmp); } @@ -6934,6 +6988,17 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor * build_rope_factors() { + // choose long/short freq factors based on the context size + const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; + + if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) { + return model.rope_long; + } + + return model.rope_short; + } + struct ggml_tensor * build_inp_out_ids() { lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); cb(lctx.inp_out_ids, "inp_out_ids", -1); @@ -7041,15 +7106,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7171,13 +7236,13 @@ struct llm_build_context { switch (model.type) { case MODEL_7B: - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7283,15 +7348,15 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7404,14 +7469,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7527,15 +7592,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7679,15 +7744,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8032,15 +8097,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8472,15 +8537,15 @@ struct llm_build_context { } - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8592,14 +8657,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8703,15 +8768,15 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8817,15 +8882,15 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8969,8 +9034,8 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -8980,8 +9045,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9052,6 +9117,9 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + // rope freq factors for 128k context + struct ggml_tensor * rope_factors = build_rope_factors(); + for (int il = 0; il < n_layer; ++il) { auto residual = inpL; @@ -9088,8 +9156,8 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9097,8 +9165,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9204,14 +9272,14 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr, n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr, n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -9412,15 +9480,15 @@ struct llm_build_context { cb(tmpk, "tmpk", il); cb(Vcur, "Vcur", il); - struct ggml_tensor * Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, + struct ggml_tensor * Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, + struct ggml_tensor * Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9528,15 +9596,15 @@ struct llm_build_context { // cb(Vcur, "Vcur", il); // } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9645,15 +9713,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9775,15 +9843,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9895,8 +9963,8 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); @@ -9904,8 +9972,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); cb(Qcur, "Qcur_scaled", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -10015,15 +10083,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -10305,15 +10373,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -10436,15 +10504,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -15417,6 +15485,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; } + cparams.yarn_attn_factor *= hparams.rope_attn_factor; cparams.causal_attn = hparams.causal_attn; if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c74e253db..1493a7ca7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1763,14 +1763,14 @@ struct test_llama : public test_llm { struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur); struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur); - Qcur = ggml_rope_custom( - ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -1889,13 +1889,13 @@ struct test_falcon : public test_llm { Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx, + Qcur = ggml_rope_ext( + ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx, + Kcur = ggml_rope_ext( + ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow );