diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index bf0125e75..4f6c3746a 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -522,8 +522,8 @@ static struct ggml_tensor * forward( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, 1] // Kcur shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0); // store key and value to memory { @@ -759,8 +759,8 @@ static struct ggml_tensor * forward_batch( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, n_batch] // Kcur shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0); assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); @@ -1056,7 +1056,7 @@ static struct ggml_tensor * forward_lora( model->layers[il].wqb, cur)), n_embd/n_head, n_head, N), - KQ_pos, n_rot, 0, 0); + KQ_pos, n_rot, 0); struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, @@ -1065,7 +1065,7 @@ static struct ggml_tensor * forward_lora( model->layers[il].wkb, cur)), n_embd/n_head, n_head, N), - KQ_pos, n_rot, 0, 0); + KQ_pos, n_rot, 0); // store key and value to memory { diff --git a/examples/convert-legacy-llama.py b/examples/convert-legacy-llama.py index fd8401015..721a57c00 100755 --- a/examples/convert-legacy-llama.py +++ b/examples/convert-legacy-llama.py @@ -176,7 +176,7 @@ class Params: rope_scaling_type: gguf.RopeScalingType | None = None f_rope_freq_base: float | None = None f_rope_scale: float | None = None - n_orig_ctx: int | None = None + n_ctx_orig: int | None = None rope_finetuned: bool | None = None ftype: GGMLFileType | None = None @@ -226,7 +226,7 @@ class Params: with open(config_path) as f: config = json.load(f) - rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None + rope_scaling_type = f_rope_scale = n_ctx_orig = rope_finetuned = None rope_scaling = config.get("rope_scaling") if rope_scaling is not None and (typ := rope_scaling.get("type")): @@ -236,7 +236,7 @@ class Params: rope_scaling_type = gguf.RopeScalingType.LINEAR elif typ == "yarn": rope_scaling_type = gguf.RopeScalingType.YARN - n_orig_ctx = rope_scaling['original_max_position_embeddings'] + n_ctx_orig = rope_scaling['original_max_position_embeddings'] rope_finetuned = rope_scaling['finetuned'] else: raise NotImplementedError(f'Unknown rope scaling type: {typ}') @@ -272,7 +272,7 @@ class Params: f_rope_freq_base = config.get("rope_theta"), rope_scaling_type = rope_scaling_type, f_rope_scale = f_rope_scale, - n_orig_ctx = n_orig_ctx, + n_ctx_orig = n_ctx_orig, rope_finetuned = rope_finetuned, ) @@ -864,8 +864,8 @@ class OutputFile: self.gguf.add_rope_scaling_type(params.rope_scaling_type) self.gguf.add_rope_scaling_factor(params.f_rope_scale) - if params.n_orig_ctx is not None: - self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx) + if params.n_ctx_orig is not None: + self.gguf.add_rope_scaling_orig_ctx_len(params.n_ctx_orig) if params.rope_finetuned is not None: self.gguf.add_rope_scaling_finetuned(params.rope_finetuned) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 22425730f..71a4333ee 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -564,7 +564,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( const int rope_mode = 0; return ggml_rope_ext(ctx, - t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, + t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index e2f85c682..b779f6bd4 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -302,7 +302,7 @@ static struct ggml_tensor * llama_build_train_graphs( const int rope_mode = 0; return ggml_rope_ext( - ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f + ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 0dd07977e..596fb7c13 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -1,7 +1,7 @@ #include "rope.cuh" struct rope_corr_dims { - float v[4]; + float v[2]; }; static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static __device__ void rope_yarn( float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta -) { + float * cos_theta, float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -29,27 +28,38 @@ static __device__ void rope_yarn( *sin_theta = sinf(theta) * mscale; } -// rope == RoPE == rotary positional embedding -template -static __global__ void rope( - const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - float ext_factor, float attn_factor, rope_corr_dims corr_dims -) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); +template +static __global__ void rope_norm( + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int i = row*ncols + col; + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; + + return; + } + + const int i = row*ne0 + i0; const int i2 = row/p_delta_rows; - const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*powf(freq_base, -float(col)/ncols); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); - float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + 1]; @@ -58,23 +68,20 @@ static __global__ void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( - const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors -) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int ib = col / n_dims; - const int ic = col % n_dims; - if (ib > 0) { - const int i = row*ncols + ib*n_dims + ic; + if (i0 >= n_dims) { + const int i = row*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -82,16 +89,17 @@ static __global__ void rope_neox( return; } - const int i = row*ncols + ib*n_dims + ic/2; + const int i = row*ne0 + i0/2; const int i2 = row/p_delta_rows; - const int p = has_pos ? pos[i2] : 0; - const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); - const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor; + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; - float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta); + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + n_dims/2]; @@ -100,144 +108,81 @@ static __global__ void rope_neox( dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32( - const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - int n_ctx -) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - const int half_n_dims = ncols/4; - - if (col >= half_n_dims) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - const int i2 = row/p_delta_rows; - - const float col_theta_scale = powf(freq_base, -2.0f*col/ncols); - // FIXME: this is likely wrong - const int p = pos != nullptr ? pos[i2] : 0; - - const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + half_n_dims]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - - const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; - const float sin_block_theta = sinf(block_theta); - const float cos_block_theta = cosf(block_theta); - - const float x2 = x[i + half_n_dims * 2]; - const float x3 = x[i + half_n_dims * 3]; - - dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; - dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; -} - - template -static void rope_cuda( - const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream -) { - GGML_ASSERT(ncols % 2 == 0); +static void rope_norm_cuda( + const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); - if (pos == nullptr) { - rope<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims - ); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_norm<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } else { - rope<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims - ); + rope_norm<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } } template static void rope_neox_cuda( - const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream -) { - GGML_ASSERT(ncols % 2 == 0); + const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f/n_dims); - if (pos == nullptr) { - if (freq_factors == nullptr) { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors ); - } else { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); - } } else { - if (freq_factors == nullptr) { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + rope_neox<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors ); - } else { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); - } } } -static void rope_glm_f32_cuda( - const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, int n_ctx, cudaStream_t stream -) { - GGML_ASSERT(ncols % 4 == 0); - const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); - const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; - const dim3 block_nums(num_blocks_x, nrows, 1); - rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx); +static void rope_norm_cuda_f16( + const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + + rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } -static void rope_cuda_f16( - const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { +static void rope_norm_cuda_f32( + const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_cuda(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); -} - -static void rope_cuda_f32( - const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { - - rope_cuda(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f16( - const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f32( - const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -258,16 +203,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t nrows = ggml_nrows(src0); + const int64_t nr = ggml_nrows(src0); - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; // RoPE alteration for extended context - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -275,38 +226,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const float * freq_factors = nullptr; - const int32_t * pos = nullptr; - const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - pos = (const int32_t *) src1_d; + const int32_t * pos = (const int32_t *) src1_d; - if (is_neox) { - if (src2 != nullptr) { - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); + const float * freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; } rope_corr_dims corr_dims; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); // compute - if (is_glm) { - GGML_ASSERT(false); - rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream); - } else if (is_neox) { + if (is_neox) { if (src0->type == GGML_TYPE_F32) { rope_neox_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream ); } else { @@ -314,14 +255,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } else { if (src0->type == GGML_TYPE_F32) { - rope_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + rope_norm_cuda_f32( + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { - rope_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + rope_norm_cuda_f16( + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else { GGML_ASSERT(false); diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index eabd70d5e..5592741be 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1192,7 +1192,7 @@ static void ggml_vk_rope( const std::shared_ptr& inB, const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx, + ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, int32_t ne01, int32_t ne02, int32_t ne03, uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, @@ -1221,14 +1221,14 @@ static void ggml_vk_rope( struct PushConstants { uint32_t inAOff, inBOff, outOff; - int32_t n_dims, mode, n_orig_ctx; + int32_t n_dims, mode, n_ctx_orig; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; uint32_t nb00, nb01, nb02, nb03; int32_t ne0; uint32_t nb0, nb1, nb2, nb3; } pushConsts { safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size), - n_dims, mode, n_orig_ctx, + n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, nb00, nb01, nb02, nb03, ne0, @@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); +#pragma message("TODO: update rope NORM mode to match NEOX mode") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634") + GGML_ASSERT(ne10 == ne02); GGML_ASSERT(src0t == dstt); // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); @@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); ggml_vk_rope( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx, + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 ); diff --git a/ggml-metal.m b/ggml-metal.m index fddc44f78..946f11813 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -172,8 +172,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute( const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; float freq_base; float freq_scale; @@ -2302,21 +2306,22 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - - GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); - - if (!is_neox) { - GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox"); - } id pipeline = nil; - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; - default: GGML_ASSERT(false); - }; + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ASSERT(false); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ASSERT(false); + }; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&mode length:sizeof( int) atIndex:22]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 0cb85e1a5..e2796fd60 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta -) { + thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -1672,55 +1671,20 @@ static void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); } static void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); } -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - template -kernel void kernel_rope( +kernel void kernel_rope_norm( device const void * src0, device const int32_t * src1, device const float * src2, @@ -1743,8 +1707,7 @@ kernel void kernel_rope( constant uint64_t & nb3, constant int & n_past, constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, + constant int & n_ctx_orig, constant float & freq_base, constant float & freq_scale, constant float & ext_factor, @@ -1758,69 +1721,130 @@ kernel void kernel_rope( const int64_t i2 = tgpig[1]; const int64_t i1 = tgpig[0]; - const bool is_neox = mode & 2; - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); device const int32_t * pos = src1; - const int64_t p = pos[i2]; - - const float theta_base = (float)p; + const float theta_base = (float) pos[i2]; const float inv_ndims = -1.f/n_dims; - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + const float theta = theta_base * pow(freq_base, inv_ndims*i0); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const T x0 = src[0]; - const T x1 = src[1]; + const float x0 = src[0]; + const float x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { - if (ic < n_dims) { - const int64_t i0 = ic/2; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float freq_factor = src2 != src0 ? src2[i0] : 1.0f; - - const float theta = theta_base * pow(freq_base, inv_ndims*ic); - - float cos_theta, sin_theta; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; +template +kernel void kernel_rope_neox( + device const void * src0, + device const int32_t * src1, + device const float * src2, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & n_ctx_orig, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/n_dims; + + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; + +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; typedef void (im2col_t)( device const float * x, diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 5cd97e4ff..3ff76474d 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -8928,49 +8928,6 @@ static void rope_neox( dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; } -static void rope_glm_f32( - const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - int n_ctx -, const sycl::nd_item<3> &item_ct1) { - const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - const int half_n_dims = ncols/4; - - if (col >= half_n_dims) { - return; - } - - const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i = row*ncols + col; - const int i2 = row/p_delta_rows; - - const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols); - // FIXME: this is likely wrong - const int p = pos != nullptr ? pos[i2] : 0; - - const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale; - const float sin_theta = sycl::sin((float)theta); - const float cos_theta = sycl::cos((float)theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + half_n_dims]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - - const float block_theta = - ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale; - const float sin_block_theta = sycl::sin((float)block_theta); - const float cos_block_theta = sycl::cos((float)block_theta); - - const float x2 = x[i + half_n_dims * 2]; - const float x3 = x[i + half_n_dims * 3]; - - dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; - dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; -} - static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> &item_ct1) { const int row = item_ct1.get_group(1); @@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, } } -static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows, - const int32_t *pos, float freq_scale, - int p_delta_rows, float freq_base, int n_ctx, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % 4 == 0); - const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4); - const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE; - const sycl::range<3> block_nums(1, nrows, num_blocks_x); - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_glm_f32(x, dst, ncols, pos, freq_scale, - p_delta_rows, freq_base, n_ctx, - item_ct1); - }); -} - static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, dpct::queue_ptr stream) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; // RoPE alteration for extended context float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, } const bool is_neox = mode & 2; - const bool is_glm = mode & 4; + +#pragma message("TODO: update rope NORM mode to match NEOX mode") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634") if (is_neox) { pos = (const int32_t *) src1_dd; @@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, } rope_corr_dims corr_dims; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); // compute - if (is_glm) { - GGML_ASSERT(false); - rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream); - } else if (is_neox) { + if (is_neox) { if (src0->type == GGML_TYPE_F32) { rope_neox_sycl( (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 5e12ea9dd..e0c512c0d 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const { const int mode = ((const int32_t *) dst->op_params)[2]; const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - - if (is_glm) { - return nullptr; - } if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { @@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const float freq_base = ((float *) dst->op_params)[5]; const float freq_scale = ((float *) dst->op_params)[6]; const float ext_factor = ((float *) dst->op_params)[7]; @@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con const float beta_slow = ((float *) dst->op_params)[10]; const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - GGML_ASSERT(!is_glm); +#pragma message("TODO: update rope NORM mode to match NEOX mode") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634") float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); if (is_neox) { const float theta_scale = powf(freq_base, -2.0f/n_dims); @@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; - const bool is_glm = mode & 4; - return !is_glm; + return true; } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: @@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (tensor->op == GGML_OP_ROPE) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ggml_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ggml_ctx = ((int32_t *) tensor->op_params)[4]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; float freq_base = ((float *) tensor->op_params)[5]; float freq_scale = ((float *) tensor->op_params)[6]; float ext_factor = ((float *) tensor->op_params)[7]; float attn_factor = ((float *) tensor->op_params)[8]; float beta_fast = ((float *) tensor->op_params)[9]; float beta_slow = ((float *) tensor->op_params)[10]; - tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: diff --git a/ggml.c b/ggml.c index 11e5c34ab..1fc77743b 100644 --- a/ggml.c +++ b/ggml.c @@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, - float xpos_base, - bool xpos_down, bool inplace) { GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); @@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; @@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope( struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false ); } @@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true ); } @@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } @@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom( struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } -struct ggml_tensor * ggml_rope_xpos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - float base, - bool down) { - return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); -} - // ggml_rope_back struct ggml_tensor * ggml_rope_back( @@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, - float beta_slow, - float xpos_base, - bool xpos_down) { + float beta_slow) { GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); @@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE_BACK; @@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta -) { + float * cos_theta, float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -14245,18 +14218,19 @@ static void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); } static void ggml_rope_cache_init( - float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, - float * cache, float sin_sign, float theta_scale -) { + float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta = theta_base; for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; rope_yarn( - theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] ); cache[i0 + 1] *= sin_sign; @@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init( } GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims - float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)); - float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)); + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); dims[0] = MAX(0, start); dims[1] = MIN(n_dims - 1, end); } @@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32( float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - // these two only relevant for xPos RoPE: - float xpos_base; - bool xpos_down; - //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool)); GGML_TENSOR_UNARY_OP_LOCALS @@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32( const float theta_scale = powf(freq_base, -2.0f/n_dims); float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. @@ -14364,95 +14327,51 @@ static void ggml_compute_forward_rope_f32( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - const float x2 = src[n_dims]; - const float x3 = src[n_dims/2*3]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; - dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; - if (xpos_down) zeta = 1.0f / zeta; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[1]; - dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; - dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t i0 = ic/2; + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - float cos_theta, sin_theta; - rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - sin_theta *= sin_sign; - theta_base *= theta_scale; + const float x0 = src[0]; + const float x1 = src[n_dims/2]; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; } } + + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } } } } @@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16( //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16( const float theta_scale = powf(freq_base, -2.0f/n_dims); float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. @@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - const float x2 = GGML_FP16_TO_FP32(src[n_dims]); - const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); - dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; @@ -14592,41 +14477,30 @@ static void ggml_compute_forward_rope_f16( dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } else { - // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t i0 = ic/2; + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - float cos_theta, sin_theta; - rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - sin_theta *= sin_sign; - theta_base *= theta_scale; + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } else { - const int64_t i0 = ic; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } + + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } } } } @@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = ggml_add_or_set(ctx, src0->grad, @@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, - beta_slow, - xpos_base, - xpos_down), + beta_slow), zero_table); } } break; @@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = ggml_add_or_set(ctx, src0->grad, @@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, - xpos_base, - xpos_down, false), zero_table); } diff --git a/ggml.h b/ggml.h index addcf1bfe..13502a362 100644 --- a/ggml.h +++ b/ggml.h @@ -1465,7 +1465,6 @@ extern "C" { // rotary position embedding // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED) // if mode & 2 == 1, GPT-NeoX style - // if mode & 4 == 1, ChatGLM style // // b is an int32 vector with size a->ne[2], it contains the positions // c is freq factors (e.g. phi3-128k), (optional) @@ -1474,8 +1473,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx); + int mode); // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_inplace( @@ -1483,8 +1481,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx); + int mode); // custom RoPE GGML_API struct ggml_tensor * ggml_rope_ext( @@ -1494,8 +1491,7 @@ extern "C" { struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1511,8 +1507,7 @@ extern "C" { struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1526,8 +1521,7 @@ extern "C" { struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1542,8 +1536,7 @@ extern "C" { struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1552,17 +1545,9 @@ extern "C" { float beta_slow), "use ggml_rope_ext_inplace instead"); - struct ggml_tensor * ggml_rope_xpos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - float base, - bool down); - // compute correction dims for YaRN RoPE scaling GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1573,16 +1558,13 @@ extern "C" { struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, - float beta_slow, - float xpos_base, - bool xpos_down); + float beta_slow); // clamp // in-place, returns view(a) diff --git a/kompute-shaders/op_rope_f16.comp b/kompute-shaders/op_rope_f16.comp index b44622584..1a4058b3f 100644 --- a/kompute-shaders/op_rope_f16.comp +++ b/kompute-shaders/op_rope_f16.comp @@ -14,7 +14,7 @@ void main() { const bool is_neox = (pcs.mode & 2) != 0; float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); diff --git a/kompute-shaders/op_rope_f32.comp b/kompute-shaders/op_rope_f32.comp index 2c0235d75..65e03827a 100644 --- a/kompute-shaders/op_rope_f32.comp +++ b/kompute-shaders/op_rope_f32.comp @@ -14,7 +14,7 @@ void main() { const bool is_neox = (pcs.mode & 2) != 0; float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); diff --git a/kompute-shaders/rope_common.comp b/kompute-shaders/rope_common.comp index 57ba6597a..7b9394cb2 100644 --- a/kompute-shaders/rope_common.comp +++ b/kompute-shaders/rope_common.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter { uint outOff; int n_dims; int mode; - int n_orig_ctx; + int n_ctx_orig; float freq_base; float freq_scale; float ext_factor; @@ -54,14 +54,14 @@ void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base)); +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base)); } void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2] ) { // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); } diff --git a/llama.cpp b/llama.cpp index 06889126e..414d390e8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1848,7 +1848,7 @@ struct llama_hparams { float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; - uint32_t n_yarn_orig_ctx; + uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul; // for State Space Models @@ -1890,7 +1890,7 @@ struct llama_hparams { if (this->n_expert_shared != other.n_expert_shared) return true; if (this->rope_finetuned != other.rope_finetuned) return true; - if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; + if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true; if (this->ssm_d_conv != other.ssm_d_conv) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true; @@ -1949,7 +1949,7 @@ struct llama_cparams { float rope_freq_base; float rope_freq_scale; - uint32_t n_yarn_orig_ctx; + uint32_t n_ctx_orig_yarn; // These hyperparameters are not exposed in GGUF, because all // existing YaRN models use the same values for them. float yarn_ext_factor; @@ -4005,8 +4005,8 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - hparams.n_yarn_orig_ctx = hparams.n_ctx_train; - ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false); + hparams.n_ctx_orig_yarn = hparams.n_ctx_train; + ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false); // rope_freq_base (optional) hparams.rope_freq_base_train = 10000.0f; @@ -4968,7 +4968,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx); + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -7134,7 +7134,7 @@ struct llm_build_context { const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_outputs; const int32_t kv_head; // index of where we store new KV data in the cache - const int32_t n_orig_ctx; + const int32_t n_ctx_orig; const bool flash_attn; @@ -7183,7 +7183,7 @@ struct llm_build_context { n_kv (worst_case ? kv_self.size : kv_self.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), - n_orig_ctx (cparams.n_yarn_orig_ctx), + n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -7241,7 +7241,7 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0), - lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(tmp, "K_shifted", il); @@ -7350,7 +7350,7 @@ struct llm_build_context { // choose long/short freq factors based on the context size const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; - if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) { + if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) { return model.layers[il].rope_long; } @@ -7466,14 +7466,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7597,12 +7597,12 @@ struct llm_build_context { case MODEL_7B: Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); break; @@ -7709,14 +7709,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7829,13 +7829,13 @@ struct llm_build_context { // using mode = 2 for neox mode Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7953,14 +7953,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8106,14 +8106,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8460,14 +8460,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8900,14 +8900,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9019,13 +9019,13 @@ struct llm_build_context { // using mode = 2 for neox mode Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9131,14 +9131,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9245,14 +9245,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9397,7 +9397,7 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9408,7 +9408,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9519,7 +9519,7 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9528,7 +9528,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9636,13 +9636,13 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr, - n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr, - n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -9844,14 +9844,14 @@ struct llm_build_context { struct ggml_tensor * Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); struct ggml_tensor * Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9960,14 +9960,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10077,14 +10077,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10207,14 +10207,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10327,7 +10327,7 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, - n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); @@ -10336,7 +10336,7 @@ struct llm_build_context { Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, - n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -10447,14 +10447,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10737,14 +10737,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10868,14 +10868,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10982,14 +10982,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -11117,14 +11117,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -11334,7 +11334,7 @@ struct llm_build_context { q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor_scaled, beta_fast, beta_slow ); cb(q_pe, "q_pe", il); @@ -11343,7 +11343,7 @@ struct llm_build_context { k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE k_pe = ggml_rope_ext( ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor_scaled, beta_fast, beta_slow ); cb(k_pe, "k_pe", il); @@ -16067,8 +16067,8 @@ struct llama_context * llama_new_context_with_model( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); - cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : - hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : + cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : + hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : hparams.n_ctx_train; cparams.cb_eval = params.cb_eval; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8dc90a45d..ce406a8af 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1141,7 +1141,7 @@ struct test_rope : public test_case { const std::array ne_a; int n_dims; int mode; - int n_ctx; + int n_ctx; // used to generate positions float fs; // freq_scale float ef; // ext_factor float af; // attn_factor @@ -1168,7 +1168,7 @@ struct test_rope : public test_case { } ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]); ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr; - ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); return out; } @@ -1615,7 +1615,7 @@ struct llama_hparams { // cparams static constexpr uint32_t n_ctx = 512; // user-specified context size - static constexpr uint32_t n_orig_ctx = n_ctx; + static constexpr uint32_t n_ctx_orig = n_ctx; // batch int32_t n_tokens; @@ -1806,13 +1806,13 @@ struct test_llama : public test_llm { Qcur = ggml_rope_ext( ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr, - hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, + hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr, - hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, + hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm { // using mode = 2 for neox mode Qcur = ggml_rope_ext( - ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, + ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( - ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, + ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (float ef : { 0.0f, 0.7465f }) { for (float af : { 1.0f, 1.4245f }) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - // TODO: ff not supported yet for !neox - test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B - if (all) { - test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B - } - for (bool ff : {false, true}) { // freq_factors + test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B + + if (all) { + test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B + test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B + } + if (all) { test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) @@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) } } + all = false; } } diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 21ca43be3..a35327645 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1465,7 +1465,7 @@ int main(int argc, const char ** argv) { continue; } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0)); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); @@ -1505,7 +1505,7 @@ int main(int argc, const char ** argv) { continue; } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0)); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp index 26c1f42dc..f0895ffaa 100644 --- a/tests/test-rope.cpp +++ b/tests/test-rope.cpp @@ -162,12 +162,12 @@ int main(int /*argc*/, const char ** /*argv*/) { x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); // 100, 101, 102, ..., 172 - struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024); + struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode); // -67, -67, -67, ..., -67 - struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens + struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens // 33, 34, 35, ..., 105 - struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024); + struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode); ggml_cgraph * gf = ggml_new_graph(ctx0);