llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp

72 lines
2.3 KiB
Plaintext
Raw Normal View History

#include "common.comp"
ggml : move rope type enum to ggml.h (#8949) * ggml : move rope type enum to ggml.h This commit moves the `llama_rope_type` enum from `llama.h` to `ggml.h` and changes its name to `ggml_rope_type`. The motivation for this change is to address the TODO in `llama.h` and use the enum in ggml. Note: This commit does not change the `mode` parameter to be of type `enum ggml_rope_type`. The name `mode` and its usage suggest that it might be more generic and possibly used as a bit field for multiple flags. Further investigation/discussion may be needed to determine if `mode` should be restricted to RoPE types. * squash! ggml : move rope type enum to ggml.h This commit removes GGML_ROPE_TYPE_NONE and GGML_ROPE_TYPE_GLM from ggml.h and moves them back to the llama_rope_type enum. I've kept the assert for GGML_ROPE_TYPE_GLM as I'm not sure if it is safe to remove it yet. * squash! ggml : move rope type enum to ggml.h This commit removes the enum ggml_rope_type from ggml.h and replaces it with a define (GGML_ROPE_TYPE_NEOX). This define is used in the code to check if the mode is set to GPT-NeoX. Also, the llama_rope_type enum has been updated to reflect this change. * squash! ggml : move rope type enum to ggml.h This commit contains a suggestion to enable the GGML_ROPE_TYPE_NEOX macro/define to be passed to the shader compiler. * squash! ggml : move rope type enum to ggml.h This commit fixes the editorconfig-checker warnings. * squash! ggml : move rope type enum to ggml.h Update comment for ggml_rope function. * Revert "squash! ggml : move rope type enum to ggml.h" This reverts commit 6261222bd0dc0efd51f0fb0435ad3f16a5b52fd6. * squash! ggml : move rope type enum to ggml.h Add GGML_ROPE_TYPE_NEOX to rope_common.comp. * remove extra line --------- Co-authored-by: slaren <slarengh@gmail.com>
2024-08-13 19:13:15 +00:00
// RoPE mode value for GPT-NeoX-style rotation. Mirrors GGML_ROPE_TYPE_NEOX in
// ggml.h; keep the two values in sync if either changes.
#define GGML_ROPE_TYPE_NEOX 2
// Each workgroup runs a single invocation (local_size_y/z default to 1).
// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;
// Push constants shared by the RoPE shaders that include this file, accessed as `pcs`.
// NOTE(review): field order and types fix this block's memory layout; presumably they
// must match the host-side push-constant struct exactly — do not reorder or retype.
layout (push_constant) uniform parameter {
uint inAOff; // offset into the first input buffer (presumably in elements — confirm)
uint inBOff; // offset into the second input buffer (presumably positions — confirm)
uint inCOff; // offset into the third input buffer (presumably frequency factors — confirm)
uint outOff; // offset into the output buffer (units: confirm against host)
int n_dims; // RoPE dimension count used by the YaRN correction formula
int mode; // RoPE mode; presumably compared against GGML_ROPE_TYPE_NEOX by includers
int n_ctx_orig; // original context length (max_pos_emb in the YaRN correction formula)
float freq_base; // RoPE frequency base (passed as `base` to rope_yarn_corr_factor)
float freq_scale; // interpolation scale applied to theta in rope_yarn
bool has_freq_factors; // NOTE(review): a GLSL bool in a block occupies 32 bits; confirm the host writes a full, zero-padded 32-bit value
float ext_factor; // YaRN extrapolation mix factor; 0 disables the ramp blend
float attn_factor; // presumably the base magnitude passed as `mscale` — confirm
float beta_fast; // rotation count giving the lower YaRN correction dim
float beta_slow; // rotation count giving the upper YaRN correction dim
uint nb00; // presumably source byte strides per dimension (ggml nbXX convention) — confirm
uint nb01;
uint nb02;
uint nb03;
int ne0; // presumably element count of dimension 0 — confirm
uint nb0; // presumably destination byte strides per dimension — confirm
uint nb1;
uint nb2;
uint nb3;
} pcs;
// YaRN ramp evaluated at half-index i0 / 2.
// Returns 1 when i0 / 2 <= low, 0 when i0 / 2 >= high, and falls linearly in
// between. The denominator is floored at 0.001f so a degenerate or inverted
// range (high <= low) never divides by zero.
float rope_yarn_ramp(const float low, const float high, const float i0) {
    const float span = max(0.001f, high - low);
    const float t = clamp((i0 / 2 - low) / span, 0.0f, 1.0f);
    return 1.0f - t;
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Produces the magnitude-scaled cosine and sine of one rotation angle.
//   theta_extrap - unscaled (extrapolated) angle
//   freq_scale   - multiplier giving the interpolated angle freq_scale * theta_extrap
//   corr_dims    - [start, end] correction range forwarded to rope_yarn_ramp
//   i0           - dimension index forwarded to rope_yarn_ramp
//   ext_factor   - weight of the ramp blend; 0.0f means pure interpolation
//   mscale       - base magnitude multiplier
//   cos_theta    - out: cos(theta) * magnitude
//   sin_theta    - out: sin(theta) * magnitude
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    const float theta_interp = freq_scale * theta_extrap;

    // Defaults: interpolated angle, unmodified magnitude.
    float theta = theta_interp;
    float magnitude = mscale;

    if (ext_factor != 0.0f) {
        // Blend interpolated and extrapolated angles along the YaRN ramp.
        const float ramp_mix = ext_factor * rope_yarn_ramp(corr_dims[0], corr_dims[1], i0);
        theta = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;

        // Magnitude correction that compensates for interpolation.
        magnitude = mscale * (1.0f + 0.1f * log(1.0f / freq_scale));
    }

    cos_theta = cos(theta) * magnitude;
    sin_theta = sin(theta) * magnitude;
}
// YaRN correction factor: the dimension index x at which a context of
// n_ctx_orig positions completes exactly n_rot full rotations, i.e. solving
//     n_rot = n_ctx_orig / (2*pi * base^((2 * x) / n_dims))
// for x, which gives
//     x = n_dims * log(n_ctx_orig / (n_rot * 2*pi)) / (2 * log(base))
// The result is unrounded; rope_yarn_corr_dims floors/ceils it.
// TWOPI_F is not defined in this file (presumably 2*pi from an included header).
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    const float ratio = n_ctx_orig / (n_rot * TWOPI_F);
    const float numerator = n_dims * log(ratio);
    return numerator / (2 * log(base));
}
// Computes the YaRN correction range written to dims:
//   dims[0] = floor(correction factor for beta_fast), clamped to >= 0
//   dims[1] = ceil(correction factor for beta_slow),  clamped to <= n_dims - 1
void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    const float start = rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base);
    const float end   = rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base);

    dims[0] = max(0.0f, floor(start));
    dims[1] = min(float(n_dims) - 1.0f, ceil(end));
}