#include "common.comp"

// Shared push-constant layout and YaRN-scaled RoPE helpers.
// No function in this file reads `pcs`; it is presumably consumed by the shaders
// that include this file (NOTE(review): confirm).

// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;

// Push constants supplied by the host for each dispatch.
// NOTE(review): member order and types define the push-constant block layout and
// must match the host-side struct exactly -- do not reorder, add, or retype members.
// Field meanings are inferred from names only and should be confirmed against the
// host code: in*/outOff look like buffer offsets, nb* like strides, ne0 like an
// element count; the remaining fields mirror the rope_yarn* parameters below.
layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int n_dims;
    int mode;
    int n_orig_ctx;
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;
    uint nb00;
    uint nb01;
    uint nb02;
    uint nb03;
    int ne0;
    uint nb0;
    uint nb1;
    uint nb2;
    uint nb3;
} pcs;

// YaRN ramp weight for index `i0` against the correction range [low, high].
// Let p = i0 / 2 (presumably the rotation-pair index, if i0 steps by 2 -- confirm
// at the call site). Returns 1 when p <= low, 0 when p >= high, and falls linearly
// in between. The span is floored at 0.001 so a degenerate range (high <= low)
// never divides by zero.
float rope_yarn_ramp(const float low, const float high, const float i0) {
    const float y = (i0 / 2.0f - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Computes the scaled rotation for one position/dimension.
//   theta_extrap : unscaled (extrapolated) rotation angle
//   freq_scale   : interpolation scale; theta_interp = freq_scale * theta_extrap
//   corr_dims    : [low, high] correction range, forwarded to rope_yarn_ramp
//   i0           : index forwarded to rope_yarn_ramp
//   ext_factor   : extrapolation blend weight; 0.0 disables the blend AND the
//                  magnitude correction, leaving plain interpolation
//   mscale       : base magnitude scale (an `in` copy, so adjustments stay local)
//   cos_theta / sin_theta : cos/sin of the final angle, each multiplied by the
//                  final magnitude scale
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    // Get n-d rotational scaling corrected for extrapolation
    float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        // ramp_mix == 0 -> pure interpolation, ramp_mix == 1 -> pure extrapolation
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }
    cos_theta = cos(theta) * mscale;
    sin_theta = sin(theta) * mscale;
}

// Fractional dimension (pair) index whose rotary wavelength completes `n_rot` full
// rotations over the original context length max_pos_emb (= n_orig_ctx).
// Dimension pair x has wavelength 2pi * base^((2 * x) / n_dims), so the number of
// rotations over max_pos_emb positions is
//   n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims)),
// and solving for x gives
//   corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base)).
// The result is neither rounded nor clamped; rope_yarn_corr_dims does that.
// Divides by log(base), so base must be > 0 and != 1 (not checked here).
// TWOPI_F is not defined in this file; presumably it comes from common.comp.
float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
    return float(n_dims) * log(float(n_orig_ctx) / (n_rot * TWOPI_F)) / (2.0f * log(base));
}

// Computes the [start, end] correction range consumed by rope_yarn_ramp:
//   dims[0] = floor(corr_factor(beta_fast)), clamped below at 0
//   dims[1] = ceil(corr_factor(beta_slow)),  clamped above at n_dims - 1
// dims[0] > dims[1] is not prevented here; rope_yarn_ramp tolerates it via its
// 0.001 minimum span.
void rope_yarn_corr_dims(
    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    // start and end correction dims
    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
    dims[1] = min(float(n_dims) - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
}