mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-02 14:54:35 +00:00
vulkan : implement YaRN RoPE scaling (#2268)
The NeoX cur_rot part is different because I'm pretty sure my original implementation was wrong.
This commit is contained in:
parent
1829f1d7be
commit
208cd52f7d
@ -1195,8 +1195,8 @@ void ggml_vk_rope(
|
|||||||
const std::shared_ptr<kp::Tensor>& inB,
|
const std::shared_ptr<kp::Tensor>& inB,
|
||||||
const std::shared_ptr<kp::Tensor>& out,
|
const std::shared_ptr<kp::Tensor>& out,
|
||||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||||
ggml_type src0t, int32_t n_dims, int32_t mode,
|
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
|
||||||
float freq_base, float freq_scale,
|
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
||||||
int32_t ne01, int32_t ne02, int32_t ne03,
|
int32_t ne01, int32_t ne02, int32_t ne03,
|
||||||
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
||||||
int32_t ne0,
|
int32_t ne0,
|
||||||
@ -1224,15 +1224,15 @@ void ggml_vk_rope(
|
|||||||
|
|
||||||
struct PushConstants {
|
struct PushConstants {
|
||||||
uint32_t inAOff, inBOff, outOff;
|
uint32_t inAOff, inBOff, outOff;
|
||||||
int32_t n_dims, mode;
|
int32_t n_dims, mode, n_orig_ctx;
|
||||||
float freq_base, freq_scale;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
uint32_t nb00, nb01, nb02, nb03;
|
uint32_t nb00, nb01, nb02, nb03;
|
||||||
int32_t ne0;
|
int32_t ne0;
|
||||||
uint32_t nb0, nb1, nb2, nb3;
|
uint32_t nb0, nb1, nb2, nb3;
|
||||||
} pushConsts {
|
} pushConsts {
|
||||||
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
||||||
n_dims, mode,
|
n_dims, mode, n_orig_ctx,
|
||||||
freq_base, freq_scale,
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||||
nb00, nb01, nb02, nb03,
|
nb00, nb01, nb02, nb03,
|
||||||
ne0,
|
ne0,
|
||||||
nb0, nb1, nb2, nb3
|
nb0, nb1, nb2, nb3
|
||||||
@ -1547,11 +1547,21 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||||||
// const int n_past = ((int32_t *) dst->op_params)[0];
|
// const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
float freq_base;
|
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
||||||
float freq_scale;
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
ggml_vk_rope(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
ggml_vk_rope(
|
||||||
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
|
||||||
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||||
|
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
||||||
|
);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
|
|
||||||
#define GELU_COEF_A 0.044715
|
#define GELU_COEF_A 0.044715
|
||||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
|
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
|
||||||
|
#define TWOPI_F 6.283185307179586f
|
||||||
|
|
||||||
#define QK_K 256
|
#define QK_K 256
|
||||||
|
|
||||||
|
@ -8,50 +8,32 @@
|
|||||||
|
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#include "common.comp"
|
#include "rope_common.comp"
|
||||||
|
|
||||||
// TODO: use a local size of 32 or more (Metal uses 1024)
|
|
||||||
layout(local_size_x = 1) in;
|
|
||||||
|
|
||||||
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
|
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
|
||||||
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
|
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
|
||||||
layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
|
layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
|
||||||
uint inAOff;
|
|
||||||
uint inBOff;
|
|
||||||
uint outOff;
|
|
||||||
int n_dims;
|
|
||||||
int mode;
|
|
||||||
float freq_base;
|
|
||||||
float freq_scale;
|
|
||||||
uint nb00;
|
|
||||||
uint nb01;
|
|
||||||
uint nb02;
|
|
||||||
uint nb03;
|
|
||||||
int ne0;
|
|
||||||
uint nb0;
|
|
||||||
uint nb1;
|
|
||||||
uint nb2;
|
|
||||||
uint nb3;
|
|
||||||
} pcs;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint i3 = gl_WorkGroupID.z;
|
const uint i3 = gl_WorkGroupID.z;
|
||||||
const uint i2 = gl_WorkGroupID.y;
|
const uint i2 = gl_WorkGroupID.y;
|
||||||
const uint i1 = gl_WorkGroupID.x;
|
const uint i1 = gl_WorkGroupID.x;
|
||||||
|
|
||||||
const bool is_neox = (pcs.mode & 2) != 0;
|
const bool is_neox = (pcs.mode & 2) != 0;
|
||||||
|
|
||||||
|
float corr_dims[2];
|
||||||
|
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
||||||
|
|
||||||
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
||||||
|
|
||||||
const int p = inB[pcs.inBOff + i2];
|
const int p = inB[pcs.inBOff + i2];
|
||||||
|
|
||||||
float theta = pcs.freq_scale * float(p);
|
float theta = float(p);
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
|
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
|
||||||
const float cos_theta = cos(theta);
|
float cos_theta, sin_theta;
|
||||||
const float sin_theta = sin(theta);
|
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
theta *= theta_scale;
|
||||||
|
|
||||||
@ -68,8 +50,10 @@ void main() {
|
|||||||
const float inv_ndims = -1.f/pcs.n_dims;
|
const float inv_ndims = -1.f/pcs.n_dims;
|
||||||
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
|
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
|
||||||
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
|
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
|
||||||
const float cos_theta = cos(theta);
|
const uint cur_rot = ib * pcs.n_dims + ic;
|
||||||
const float sin_theta = sin(theta);
|
|
||||||
|
float cos_theta, sin_theta;
|
||||||
|
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
theta *= theta_scale;
|
||||||
|
|
||||||
|
@ -8,50 +8,32 @@
|
|||||||
|
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#include "common.comp"
|
#include "rope_common.comp"
|
||||||
|
|
||||||
// TODO: use a local size of 32 or more (Metal uses 1024)
|
|
||||||
layout(local_size_x = 1) in;
|
|
||||||
|
|
||||||
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
|
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
|
||||||
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
|
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
|
||||||
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
|
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
|
||||||
uint inAOff;
|
|
||||||
uint inBOff;
|
|
||||||
uint outOff;
|
|
||||||
int n_dims;
|
|
||||||
int mode;
|
|
||||||
float freq_base;
|
|
||||||
float freq_scale;
|
|
||||||
uint nb00;
|
|
||||||
uint nb01;
|
|
||||||
uint nb02;
|
|
||||||
uint nb03;
|
|
||||||
int ne0;
|
|
||||||
uint nb0;
|
|
||||||
uint nb1;
|
|
||||||
uint nb2;
|
|
||||||
uint nb3;
|
|
||||||
} pcs;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint i3 = gl_WorkGroupID.z;
|
const uint i3 = gl_WorkGroupID.z;
|
||||||
const uint i2 = gl_WorkGroupID.y;
|
const uint i2 = gl_WorkGroupID.y;
|
||||||
const uint i1 = gl_WorkGroupID.x;
|
const uint i1 = gl_WorkGroupID.x;
|
||||||
|
|
||||||
const bool is_neox = (pcs.mode & 2) != 0;
|
const bool is_neox = (pcs.mode & 2) != 0;
|
||||||
|
|
||||||
|
float corr_dims[2];
|
||||||
|
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
||||||
|
|
||||||
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
||||||
|
|
||||||
const int p = inB[pcs.inBOff + i2];
|
const int p = inB[pcs.inBOff + i2];
|
||||||
|
|
||||||
float theta = pcs.freq_scale * float(p);
|
float theta = float(p);
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
|
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
|
||||||
const float cos_theta = cos(theta);
|
float cos_theta, sin_theta;
|
||||||
const float sin_theta = sin(theta);
|
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
theta *= theta_scale;
|
||||||
|
|
||||||
@ -68,8 +50,10 @@ void main() {
|
|||||||
const float inv_ndims = -1.f/pcs.n_dims;
|
const float inv_ndims = -1.f/pcs.n_dims;
|
||||||
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
|
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
|
||||||
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
|
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
|
||||||
const float cos_theta = cos(theta);
|
const uint cur_rot = ib * pcs.n_dims + ic;
|
||||||
const float sin_theta = sin(theta);
|
|
||||||
|
float cos_theta, sin_theta;
|
||||||
|
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
theta *= theta_scale;
|
||||||
|
|
||||||
|
75
kompute/rope_common.comp
Normal file
75
kompute/rope_common.comp
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
|
||||||
|
*
|
||||||
|
* This software is licensed under the terms of the Software for Open Models License (SOM),
|
||||||
|
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
|
||||||
|
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common.comp"
|
||||||
|
|
||||||
|
// TODO: use a local size of 32 or more (Metal uses 1024)
|
||||||
|
layout(local_size_x = 1) in;
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
uint inAOff;
|
||||||
|
uint inBOff;
|
||||||
|
uint outOff;
|
||||||
|
int n_dims;
|
||||||
|
int mode;
|
||||||
|
int n_orig_ctx;
|
||||||
|
float freq_base;
|
||||||
|
float freq_scale;
|
||||||
|
float ext_factor;
|
||||||
|
float attn_factor;
|
||||||
|
float beta_fast;
|
||||||
|
float beta_slow;
|
||||||
|
uint nb00;
|
||||||
|
uint nb01;
|
||||||
|
uint nb02;
|
||||||
|
uint nb03;
|
||||||
|
int ne0;
|
||||||
|
uint nb0;
|
||||||
|
uint nb1;
|
||||||
|
uint nb2;
|
||||||
|
uint nb3;
|
||||||
|
} pcs;
|
||||||
|
|
||||||
|
float rope_yarn_ramp(const float low, const float high, const float i0) {
|
||||||
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||||
|
return 1.0f - min(1.0f, max(0.0f, y));
|
||||||
|
}
|
||||||
|
|
||||||
|
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||||
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||||
|
void rope_yarn(
|
||||||
|
float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
|
||||||
|
out float cos_theta, out float sin_theta
|
||||||
|
) {
|
||||||
|
// Get n-d rotational scaling corrected for extrapolation
|
||||||
|
float theta_interp = freq_scale * theta_extrap;
|
||||||
|
float theta = theta_interp;
|
||||||
|
if (ext_factor != 0.0f) {
|
||||||
|
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||||
|
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||||
|
|
||||||
|
// Get n-d magnitude scaling corrected for interpolation
|
||||||
|
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
||||||
|
}
|
||||||
|
cos_theta = cos(theta) * mscale;
|
||||||
|
sin_theta = sin(theta) * mscale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||||
|
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||||
|
float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||||
|
return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
|
||||||
|
}
|
||||||
|
|
||||||
|
void rope_yarn_corr_dims(
|
||||||
|
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
|
||||||
|
) {
|
||||||
|
// start and end correction dims
|
||||||
|
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||||
|
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user