From 208cd52f7d2ca3eb9708cfd457dde0592ed0e38b Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Wed, 15 Nov 2023 17:58:19 -0500 Subject: [PATCH] vulkan : implement YaRN RoPE scaling (#2268) The NeoX cur_rot part is different because I'm pretty sure my original implementation was wrong. --- ggml-vulkan.cpp | 36 ++++++++++++------- kompute/common.comp | 1 + kompute/op_rope_f16.comp | 40 +++++++-------------- kompute/op_rope_f32.comp | 40 +++++++-------------- kompute/rope_common.comp | 75 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 123 insertions(+), 69 deletions(-) create mode 100644 kompute/rope_common.comp diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 8c048c77d..a4f9ade0e 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -1195,8 +1195,8 @@ void ggml_vk_rope( const std::shared_ptr& inB, const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - ggml_type src0t, int32_t n_dims, int32_t mode, - float freq_base, float freq_scale, + ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx, + float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, int32_t ne01, int32_t ne02, int32_t ne03, uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, int32_t ne0, @@ -1224,15 +1224,15 @@ void ggml_vk_rope( struct PushConstants { uint32_t inAOff, inBOff, outOff; - int32_t n_dims, mode; - float freq_base, freq_scale; + int32_t n_dims, mode, n_orig_ctx; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; uint32_t nb00, nb01, nb02, nb03; int32_t ne0; uint32_t nb0, nb1, nb2, nb3; } pushConsts { safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size), - n_dims, mode, - freq_base, freq_scale, + n_dims, mode, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 @@ -1545,13 +1545,23 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph GGML_ASSERT(ne10 == ne02); GGML_ASSERT(src0t == dstt); // const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - float freq_base; - float freq_scale; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - ggml_vk_rope(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3); + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + ggml_vk_rope( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, + ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 + ); } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/kompute/common.comp b/kompute/common.comp index 040b87375..fe0bc5d15 100644 --- a/kompute/common.comp +++ b/kompute/common.comp @@ -20,6 +20,7 @@ #define GELU_COEF_A 0.044715 #define SQRT_2_OVER_PI 0.79788456080286535587989211986876 +#define TWOPI_F 6.283185307179586f #define QK_K 256 diff --git a/kompute/op_rope_f16.comp b/kompute/op_rope_f16.comp index fd3943c81..e4b5ccca3 100644 --- a/kompute/op_rope_f16.comp +++ b/kompute/op_rope_f16.comp @@ -8,50 +8,32 @@ #version 450 -#include "common.comp" - -// TODO: use a local size of 32 or more (Metal uses 1024) -layout(local_size_x = 1) in; +#include "rope_common.comp" layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; }; layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; }; -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int n_dims; - int mode; - float freq_base; - float freq_scale; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne0; - uint nb0; - uint nb1; - uint nb2; - uint nb3; -} pcs; - void main() { const uint i3 = gl_WorkGroupID.z; const uint i2 = gl_WorkGroupID.y; const uint i1 = gl_WorkGroupID.x; const bool is_neox = (pcs.mode & 2) != 0; + + float corr_dims[2]; + rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); const int p = inB[pcs.inBOff + i2]; - float theta = pcs.freq_scale * float(p); + float theta = float(p); if (!is_neox) { for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) { - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); + float cos_theta, sin_theta; + rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); theta *= theta_scale; @@ -68,8 +50,10 @@ void main() { const float inv_ndims = -1.f/pcs.n_dims; for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) { for (uint ic = 0; ic < pcs.n_dims; ic += 2) { - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); + const uint cur_rot = ib * pcs.n_dims + ic; + + float cos_theta, sin_theta; + rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); theta *= theta_scale; diff --git a/kompute/op_rope_f32.comp b/kompute/op_rope_f32.comp index 0cf83fec0..0a882879d 100644 --- a/kompute/op_rope_f32.comp +++ b/kompute/op_rope_f32.comp @@ -8,50 +8,32 @@ #version 450 -#include "common.comp" - -// TODO: use a local size of 32 or more (Metal uses 1024) -layout(local_size_x = 1) in; +#include "rope_common.comp" layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int n_dims; - int mode; - float freq_base; - float freq_scale; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne0; - uint nb0; - uint nb1; - uint nb2; - uint nb3; -} pcs; - void main() { const uint i3 = gl_WorkGroupID.z; const uint i2 = gl_WorkGroupID.y; const uint i1 = gl_WorkGroupID.x; const bool is_neox = (pcs.mode & 2) != 0; + + float corr_dims[2]; + rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); const int p = inB[pcs.inBOff + i2]; - float theta = pcs.freq_scale * float(p); + float theta = float(p); if (!is_neox) { for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) { - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); + float cos_theta, sin_theta; + rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); theta *= theta_scale; @@ -68,8 +50,10 @@ void main() { const float inv_ndims = -1.f/pcs.n_dims; for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) { for (uint ic = 0; ic < pcs.n_dims; ic += 2) { - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); + const uint cur_rot = ib * pcs.n_dims + ic; + + float cos_theta, sin_theta; + rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); theta *= theta_scale; diff --git a/kompute/rope_common.comp b/kompute/rope_common.comp new file mode 100644 index 000000000..45682dc28 --- /dev/null +++ b/kompute/rope_common.comp @@ -0,0 +1,75 @@ +/** + * Copyright (c) 2023 Nomic, Inc. All rights reserved. + * + * This software is licensed under the terms of the Software for Open Models License (SOM), + * version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany + * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc. + */ + +#include "common.comp" + +// TODO: use a local size of 32 or more (Metal uses 1024) +layout(local_size_x = 1) in; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int n_dims; + int mode; + int n_orig_ctx; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + uint nb00; + uint nb01; + uint nb02; + uint nb03; + int ne0; + uint nb0; + uint nb1; + uint nb2; + uint nb3; +} pcs; + +float rope_yarn_ramp(const float low, const float high, const float i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale, + out float cos_theta, out float sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base)); +} + +void rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); +}