llama : implement YaRN RoPE scaling (#2268)

Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
This commit is contained in:
cebtenzzre 2023-11-01 18:04:33 -04:00 committed by GitHub
parent c43c2da8af
commit 898aeca90a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 763 additions and 257 deletions

View File

@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.rope_freq_scale = std::stof(argv[i]); params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--rope-scaling") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
} else if (arg == "--rope-scale") { } else if (arg == "--rope-scale") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.rope_freq_scale = 1.0f/std::stof(argv[i]); params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--yarn-orig-ctx") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_orig_ctx = std::stoi(argv[i]);
} else if (arg == "--yarn-ext-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_ext_factor = std::stof(argv[i]);
} else if (arg == "--yarn-attn-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_attn_factor = std::stof(argv[i]);
} else if (arg == "--yarn-beta-fast") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_fast = std::stof(argv[i]);
} else if (arg == "--yarn-beta-slow") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") { } else if (arg == "--memory-f32") {
params.memory_f16 = false; params.memory_f16 = false;
} else if (arg == "--top-p") { } else if (arg == "--top-p") {
@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --cfg-negative-prompt-file FNAME\n"); printf(" --cfg-negative-prompt-file FNAME\n");
printf(" negative prompt file to use for guidance. (default: empty)\n"); printf(" negative prompt file to use for guidance. (default: empty)\n");
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
@ -835,8 +882,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.f16_kv = params.memory_f16; cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all; cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding; cparams.embedding = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale; cparams.rope_freq_scale = params.rope_freq_scale;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
return cparams; return cparams;
} }

View File

@ -9,6 +9,7 @@
#define LOG_NO_FILE_LINE_FUNCTION #define LOG_NO_FILE_LINE_FUNCTION
#include "log.h" #include "log.h"
#include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
#include <random> #include <random>
@ -54,6 +55,12 @@ struct gpt_params {
int32_t n_beams = 0; // if non-zero then use beam search of given width. int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = NAN; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f;// YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
// // sampling parameters // // sampling parameters
struct llama_sampling_params sparams; struct llama_sampling_params sparams;

View File

@ -163,7 +163,8 @@ gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]: if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
if "type" in hparams["rope_scaling"]: if "type" in hparams["rope_scaling"]:
if hparams["rope_scaling"]["type"] == "linear": if hparams["rope_scaling"]["type"] == "linear":
gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"]) gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
# TOKENIZATION # TOKENIZATION

View File

@ -151,8 +151,11 @@ class Params:
n_head_kv: int n_head_kv: int
f_norm_eps: float f_norm_eps: float
rope_scaling_type: gguf.RopeScalingType | None = None
f_rope_freq_base: float | None = None f_rope_freq_base: float | None = None
f_rope_scale: float | None = None f_rope_scale: float | None = None
n_orig_ctx: int | None = None
rope_finetuned: bool | None = None
ftype: GGMLFileType | None = None ftype: GGMLFileType | None = None
@ -198,20 +201,20 @@ class Params:
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path)) config = json.load(open(config_path))
n_vocab = config["vocab_size"] rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
n_embd = config["hidden_size"]
n_layer = config["num_hidden_layers"]
n_ff = config["intermediate_size"]
n_head = config["num_attention_heads"]
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
f_norm_eps = config["rms_norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
rope_scaling = config.get("rope_scaling") rope_scaling = config.get("rope_scaling")
if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
f_rope_scale = config["rope_scaling"].get("factor") if rope_scaling is not None and (typ := rope_scaling.get("type")):
rope_factor = rope_scaling.get("factor")
f_rope_scale = rope_factor
if typ == "linear":
rope_scaling_type = gguf.RopeScalingType.LINEAR
elif typ == "yarn":
rope_scaling_type = gguf.RopeScalingType.YARN
n_orig_ctx = rope_scaling['original_max_position_embeddings']
rope_finetuned = rope_scaling['finetuned']
else: else:
f_rope_scale = None raise NotImplementedError(f'Unknown rope scaling type: {typ}')
if "max_sequence_length" in config: if "max_sequence_length" in config:
n_ctx = config["max_sequence_length"] n_ctx = config["max_sequence_length"]
@ -222,16 +225,19 @@ class Params:
"Suggestion: provide 'config.json' of the model in the same directory containing model files.") "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
return Params( return Params(
n_vocab = n_vocab, n_vocab = config["vocab_size"],
n_embd = n_embd, n_embd = config["hidden_size"],
n_layer = n_layer, n_layer = config["num_hidden_layers"],
n_ctx = n_ctx, n_ctx = n_ctx,
n_ff = n_ff, n_ff = config["intermediate_size"],
n_head = n_head, n_head = (n_head := config["num_attention_heads"]),
n_head_kv = n_head_kv, n_head_kv = config.get("num_key_value_heads", n_head),
f_norm_eps = f_norm_eps, f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = f_rope_freq_base, f_rope_freq_base = config.get("rope_theta"),
rope_scaling_type = rope_scaling_type,
f_rope_scale = f_rope_scale, f_rope_scale = f_rope_scale,
n_orig_ctx = n_orig_ctx,
rope_finetuned = rope_finetuned,
) )
# LLaMA v2 70B params.json # LLaMA v2 70B params.json
@ -240,17 +246,8 @@ class Params:
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path)) config = json.load(open(config_path))
n_vocab = config["vocab_size"] if "vocab_size" in config else -1
n_embd = config["dim"]
n_layer = config["n_layers"]
n_ff = -1
n_head = config["n_heads"]
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
f_norm_eps = config["norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
# hack to determine LLaMA v1 vs v2 vs CodeLlama # hack to determine LLaMA v1 vs v2 vs CodeLlama
if f_rope_freq_base == 1000000: if config.get("rope_theta") == 1000000:
# CodeLlama # CodeLlama
n_ctx = 16384 n_ctx = 16384
elif config["norm_eps"] == 1e-05: elif config["norm_eps"] == 1e-05:
@ -260,22 +257,16 @@ class Params:
# LLaMA v1 # LLaMA v1
n_ctx = 2048 n_ctx = 2048
if n_vocab == -1:
n_vocab = model["tok_embeddings.weight"].shape[0]
if n_ff == -1:
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
return Params( return Params(
n_vocab = n_vocab, n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
n_embd = n_embd, n_embd = config["dim"],
n_layer = n_layer, n_layer = config["n_layers"],
n_ctx = n_ctx, n_ctx = n_ctx,
n_ff = n_ff, n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
n_head = n_head, n_head = (n_head := config["n_heads"]),
n_head_kv = n_head_kv, n_head_kv = config.get("n_kv_heads", n_head),
f_norm_eps = f_norm_eps, f_norm_eps = config["norm_eps"],
f_rope_freq_base = f_rope_freq_base, f_rope_freq_base = config.get("rope_theta"),
) )
@staticmethod @staticmethod
@ -831,8 +822,16 @@ class OutputFile:
if params.f_rope_freq_base is not None: if params.f_rope_freq_base is not None:
self.gguf.add_rope_freq_base(params.f_rope_freq_base) self.gguf.add_rope_freq_base(params.f_rope_freq_base)
if params.f_rope_scale is not None: if params.rope_scaling_type:
self.gguf.add_rope_scale_linear(params.f_rope_scale) assert params.f_rope_scale is not None
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
self.gguf.add_rope_scaling_factor(params.f_rope_scale)
if params.n_orig_ctx is not None:
self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
if params.rope_finetuned is not None:
self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)
if params.ftype is not None: if params.ftype is not None:
self.gguf.add_file_type(params.ftype) self.gguf.add_file_type(params.ftype)

View File

@ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int rope_mode = 0; const int rope_mode = 0;
return ggml_rope_custom(ctx, return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale); rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f
);
}; };
set_name(tokens_input, "tokens_input"); set_name(tokens_input, "tokens_input");

View File

@ -1758,8 +1758,14 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n"); printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n"); printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n"); printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.n_ctx = std::stoi(argv[i]); params.n_ctx = std::stoi(argv[i]);
} }
else if (arg == "--rope-scaling")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
}
else if (arg == "--rope-freq-base") else if (arg == "--rope-freq-base")
{ {
if (++i >= argc) if (++i >= argc)
@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.rope_freq_scale = std::stof(argv[i]); params.rope_freq_scale = std::stof(argv[i]);
} }
else if (arg == "--yarn-ext-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_ext_factor = std::stof(argv[i]);
}
else if (arg == "--yarn-attn-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_attn_factor = std::stof(argv[i]);
}
else if (arg == "--yarn-beta-fast")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_fast = std::stof(argv[i]);
}
else if (arg == "--yarn-beta-slow")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
}
else if (arg == "--memory-f32" || arg == "--memory_f32") else if (arg == "--memory-f32" || arg == "--memory_f32")
{ {
params.memory_f16 = false; params.memory_f16 = false;

View File

@ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs(
// not capturing these, to silcence warnings // not capturing these, to silcence warnings
const int rope_mode = 0; const int rope_mode = 0;
return ggml_rope_custom(ctx, return ggml_rope_custom(
t, KQ_pos, n_rot, rope_mode, n_ctx, ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
rope_freq_base, rope_freq_scale); );
}; };
set_name(tokens_input, "tokens_input"); set_name(tokens_input, "tokens_input");

View File

@ -4493,11 +4493,41 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset); cpy_1(cx + x_offset, cdst + dst_offset);
} }
// rope == RoPE == rotary positional embedding static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
struct rope_corr_dims {
float v[4];
};
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static __device__ void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}
// rope == RoPE == rotary positional embedding
template<typename T, bool has_pos> template<typename T, bool has_pos>
static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, static __global__ void rope(
const int p_delta_rows, const float theta_scale) { const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) { if (col >= ncols) {
@ -4509,10 +4539,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
const int p = has_pos ? pos[i2] : 0; const int p = has_pos ? pos[i2] : 0;
const float p0 = p*freq_scale; const float theta_base = p*powf(freq_base, -col/ncols);
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta); float cos_theta, sin_theta;
const float cos_theta = cosf(theta); rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0]; const float x0 = x[i + 0];
const float x1 = x[i + 1]; const float x1 = x[i + 1];
@ -4522,8 +4552,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
} }
template<typename T, bool has_pos> template<typename T, bool has_pos>
static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, static __global__ void rope_neox(
const int p_delta_rows, const float theta_scale) { const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) { if (col >= ncols) {
@ -4534,11 +4566,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
const int i = row*ncols + col/2; const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
// simplified from `(row * ncols + col) * (-1 / ncols)`
const float cur_rot = -col/ncols - row;
const int p = has_pos ? pos[i2] : 0; const int p = has_pos ? pos[i2] : 0;
const float p0 = p*freq_scale; const float theta_base = p*powf(freq_base, cur_rot);
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta); float cos_theta, sin_theta;
const float cos_theta = cosf(theta); rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0]; const float x0 = x[i + 0];
const float x1 = x[i + ncols/2]; const float x1 = x[i + ncols/2];
@ -4547,8 +4582,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
} }
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, static __global__ void rope_glm_f32(
const int p_delta_rows, const float theta_scale, const int n_ctx) { const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
int n_ctx
) {
const int col = blockDim.x*blockIdx.x + threadIdx.x; const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4; const int half_n_dims = ncols/4;
@ -4560,7 +4597,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
const int i = row*ncols + col; const int i = row*ncols + col;
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
const float col_theta_scale = powf(theta_scale, col); const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
// FIXME: this is likely wrong // FIXME: this is likely wrong
const int p = pos != nullptr ? pos[i2] : 0; const int p = pos != nullptr ? pos[i2] : 0;
@ -5584,40 +5621,54 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
} }
template<typename T> template<typename T>
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, static void rope_cuda(
const int p_delta_rows, const float theta_scale, cudaStream_t stream) { const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1); const dim3 block_nums(nrows, num_blocks_x, 1);
if (pos == nullptr) { if (pos == nullptr) {
rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); rope<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
} else { } else {
rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); rope<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
} }
} }
template<typename T> template<typename T>
static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, static void rope_neox_cuda(
const int p_delta_rows, const float theta_scale, cudaStream_t stream) { const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1); const dim3 block_nums(nrows, num_blocks_x, 1);
if (pos == nullptr) { if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
} else { } else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
} }
} }
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, static void rope_glm_f32_cuda(
const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, int n_ctx, cudaStream_t stream
) {
GGML_ASSERT(ncols % 4 == 0); GGML_ASSERT(ncols % 4 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
const dim3 block_nums(num_blocks_x, nrows, 1); const dim3 block_nums(num_blocks_x, nrows, 1);
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
} }
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@ -6481,13 +6532,16 @@ inline void ggml_cuda_op_rope(
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
// RoPE alteration for extended context // RoPE alteration for extended context
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_base, freq_scale; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
const float theta_scale = powf(freq_base, -2.0f/n_dims); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const int32_t * pos = nullptr; const int32_t * pos = nullptr;
if ((mode & 1) == 0) { if ((mode & 1) == 0) {
@ -6499,24 +6553,39 @@ inline void ggml_cuda_op_rope(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
// compute // compute
if (is_glm) { if (is_glm) {
GGML_ASSERT(false); GGML_ASSERT(false);
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
} else if (is_neox) { } else if (is_neox) {
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); rope_neox_cuda(
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); rope_neox_cuda(
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }
} else { } else {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); rope_cuda(
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); rope_cuda(
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }

View File

@ -1403,11 +1403,15 @@ void ggml_metal_graph_compute(
const int n_past = ((int32_t *) dst->op_params)[0]; const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
float freq_base; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_scale; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
@ -1439,6 +1443,10 @@ void ggml_metal_graph_compute(
[encoder setBytes:&mode length:sizeof( int) atIndex:21]; [encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
[encoder setBytes:&ext_factor length:sizeof(float) atIndex:24];
[encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
[encoder setBytes:&beta_fast length:sizeof(float) atIndex:26];
[encoder setBytes:&beta_slow length:sizeof(float) atIndex:27];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;

View File

@ -1061,6 +1061,45 @@ kernel void kernel_alibi_f32(
} }
} }
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
}
static void rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
}
typedef void (rope_t)( typedef void (rope_t)(
device const void * src0, device const void * src0,
device const int32_t * src1, device const int32_t * src1,
@ -1116,6 +1155,10 @@ kernel void kernel_rope(
constant int & mode, constant int & mode,
constant float & freq_base, constant float & freq_base,
constant float & freq_scale, constant float & freq_scale,
constant float & ext_factor,
constant float & attn_factor,
constant float & beta_fast,
constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]], uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
@ -1125,19 +1168,22 @@ kernel void kernel_rope(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
float corr_dims[2];
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
device const int32_t * pos = src1; device const int32_t * pos = src1;
const int64_t p = pos[i2]; const int64_t p = pos[i2];
const float theta_0 = freq_scale * (float)p; const float theta_0 = (float)p;
const float inv_ndims = -1.f/n_dims; const float inv_ndims = -1.f/n_dims;
if (!is_neox) { if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
const float theta = theta_0 * pow(freq_base, inv_ndims*i0); const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
const float cos_theta = cos(theta); float cos_theta, sin_theta;
const float sin_theta = sin(theta); rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -1152,9 +1198,12 @@ kernel void kernel_rope(
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); // simplified from `(ib * n_dims + ic) * inv_ndims`
const float cos_theta = cos(theta); const float cur_rot = inv_ndims*ic - ib;
const float sin_theta = sin(theta);
const float theta = theta_0 * pow(freq_base, cur_rot);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const int64_t i0 = ib*n_dims + ic/2; const int64_t i0 = ib*n_dims + ic/2;

225
ggml.c
View File

@ -1,4 +1,5 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-quants.h" #include "ggml-quants.h"
@ -4845,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl(
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
int n_orig_ctx,
float freq_base, float freq_base,
float freq_scale, float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
float xpos_base, float xpos_base,
bool xpos_down, bool xpos_down,
bool inplace) { bool inplace) {
@ -4862,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 7, &xpos_down, sizeof(bool)); memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &xpos_base, sizeof(float));
memcpy(params + 12, &xpos_down, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE; result->op = GGML_OP_ROPE;
@ -4884,7 +4894,9 @@ struct ggml_tensor * ggml_rope(
int n_dims, int n_dims,
int mode, int mode,
int n_ctx) { int n_ctx) {
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
);
} }
struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * ggml_rope_inplace(
@ -4894,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims, int n_dims,
int mode, int mode,
int n_ctx) { int n_ctx) {
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
);
} }
struct ggml_tensor * ggml_rope_custom( struct ggml_tensor * ggml_rope_custom(
@ -4904,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom(
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
int n_orig_ctx,
float freq_base, float freq_base,
float freq_scale) { float freq_scale,
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
);
} }
struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_tensor * ggml_rope_custom_inplace(
@ -4916,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
int n_orig_ctx,
float freq_base, float freq_base,
float freq_scale) { float freq_scale,
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
);
} }
struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_tensor * ggml_rope_xpos_inplace(
@ -4928,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
int n_dims, int n_dims,
float base, float base,
bool down) { bool down) {
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
} }
// ggml_rope_back // ggml_rope_back
@ -10901,6 +10931,45 @@ static void ggml_compute_forward_clamp(
// ggml_compute_forward_rope // ggml_compute_forward_rope
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
return 1 - MIN(1, MAX(0, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}
void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
}
static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
@ -10910,8 +10979,7 @@ static void ggml_compute_forward_rope_f32(
return; return;
} }
float freq_base; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_scale;
// these two only relevant for xPos RoPE: // these two only relevant for xPos RoPE:
float xpos_base; float xpos_base;
@ -10921,10 +10989,16 @@ static void ggml_compute_forward_rope_f32(
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
GGML_TENSOR_UNARY_OP_LOCALS GGML_TENSOR_UNARY_OP_LOCALS
@ -10952,6 +11026,9 @@ static void ggml_compute_forward_rope_f32(
int ir = 0; int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
@ -10965,18 +11042,18 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
float theta = freq_scale * (float)p; float theta_base = (float)p;
if (is_glm) { if (is_glm) {
theta = MIN(p, n_ctx - 2); theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0); float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta_base);
const float cos_block_theta = cosf(block_theta); const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta); const float sin_block_theta = sinf(block_theta);
theta *= theta_scale; theta_base *= theta_scale;
block_theta *= theta_scale; block_theta *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@ -10994,13 +11071,16 @@ static void ggml_compute_forward_rope_f32(
} }
} else if (!is_neox) { } else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta); float cos_theta, sin_theta;
const float sin_theta = sinf(theta); rope_yarn(
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
);
// zeta scaling for xPos only: // zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta; if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale; theta_base *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -11014,12 +11094,19 @@ static void ggml_compute_forward_rope_f32(
} else { } else {
// TODO: this might be wrong for ne0 != n_dims - need double check // TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta); // simplified from `(ib * n_dims + ic) * inv_ndims`
const float sin_theta = sinf(theta); float cur_rot = inv_ndims * ic - ib;
theta *= theta_scale; float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2; const int64_t i0 = ib*n_dims + ic/2;
@ -11048,15 +11135,19 @@ static void ggml_compute_forward_rope_f16(
return; return;
} }
float freq_base; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_scale;
//const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
GGML_TENSOR_UNARY_OP_LOCALS GGML_TENSOR_UNARY_OP_LOCALS
@ -11084,6 +11175,9 @@ static void ggml_compute_forward_rope_f16(
int ir = 0; int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
@ -11097,18 +11191,18 @@ static void ggml_compute_forward_rope_f16(
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
float theta = freq_scale * (float)p; float theta_base = (float)p;
if (is_glm) { if (is_glm) {
theta = MIN(p, n_ctx - 2); theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0); float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta_base);
const float cos_block_theta = cosf(block_theta); const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta); const float sin_block_theta = sinf(block_theta);
theta *= theta_scale; theta_base *= theta_scale;
block_theta *= theta_scale; block_theta *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@ -11126,10 +11220,12 @@ static void ggml_compute_forward_rope_f16(
} }
} else if (!is_neox) { } else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta); float cos_theta, sin_theta;
const float sin_theta = sinf(theta); rope_yarn(
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
);
theta *= theta_scale; theta_base *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -11143,12 +11239,19 @@ static void ggml_compute_forward_rope_f16(
} else { } else {
// TODO: this might be wrong for ne0 != n_dims - need double check // TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta); // simplified from `(ib * n_dims + ic) * inv_ndims`
const float sin_theta = sinf(theta); float cur_rot = inv_ndims * ic - ib;
theta *= theta_scale; float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2; const int64_t i0 = ib*n_dims + ic/2;
@ -11256,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32(
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
float theta = freq_scale * (float)p; float theta_base = freq_scale * (float)p;
if (!is_neox) { if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta_base);
// zeta scaling for xPos only: // zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta; if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale; theta_base *= theta_scale;
const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -11280,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32(
} else { } else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta_base);
theta *= theta_scale; theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2; const int64_t i0 = ib*n_dims + ic/2;
@ -11356,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16(
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
float theta = (float)p; float theta_base = (float)p;
if (!is_neox) { if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta_base);
theta *= theta_scale; theta_base *= theta_scale;
const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -11377,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16(
} else { } else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta_base);
theta *= theta_scale; theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2; const int64_t i0 = ib*n_dims + ic/2;
@ -15505,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1, src1,
n_dims, n_dims,
mode, mode,
0,
n_ctx, n_ctx,
freq_base, freq_base,
freq_scale, freq_scale,
0.0f,
1.0f,
0.0f,
0.0f,
xpos_base, xpos_base,
xpos_down, xpos_down,
false), false),

20
ggml.h
View File

@ -219,7 +219,7 @@
#define GGML_MAX_CONTEXTS 64 #define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6 #define GGML_MAX_SRC 6
#define GGML_MAX_NAME 64 #define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 32 #define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_N_THREADS 4
#if UINTPTR_MAX == 0xFFFFFFFF #if UINTPTR_MAX == 0xFFFFFFFF
@ -1326,8 +1326,13 @@ extern "C" {
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
int n_orig_ctx,
float freq_base, float freq_base,
float freq_scale); float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// in-place, returns view(a) // in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_custom_inplace( GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
@ -1337,8 +1342,17 @@ extern "C" {
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
int n_orig_ctx,
float freq_base, float freq_base,
float freq_scale); float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// compute correction dims for YaRN RoPE scaling
void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
// xPos RoPE, in-place, returns view(a) // xPos RoPE, in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(

View File

@ -7,7 +7,7 @@ import shutil
import struct import struct
import sys import sys
import tempfile import tempfile
from enum import IntEnum, auto from enum import Enum, IntEnum, auto
from io import BufferedWriter from io import BufferedWriter
from pathlib import Path from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Sequence from typing import IO, Any, BinaryIO, Callable, Sequence
@ -55,7 +55,10 @@ KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
# RoPE # RoPE
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count" KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base" KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear" KEY_ROPE_SCALING_TYPE = "{arch}.rope.scaling.type"
KEY_ROPE_SCALING_FACTOR = "{arch}.rope.scaling.factor"
KEY_ROPE_SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
KEY_ROPE_SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
# tokenization # tokenization
KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
@ -577,6 +580,11 @@ class TokenType(IntEnum):
UNUSED = 5 UNUSED = 5
BYTE = 6 BYTE = 6
class RopeScalingType(Enum):
NONE = 'none'
LINEAR = 'linear'
YARN = 'yarn'
# #
# implementation # implementation
# #
@ -948,8 +956,17 @@ class GGUFWriter:
def add_rope_freq_base(self, value: float): def add_rope_freq_base(self, value: float):
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value) self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_rope_scale_linear(self, value: float): def add_rope_scaling_type(self, value: RopeScalingType):
self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value) self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
def add_rope_scaling_factor(self, value: float):
self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_orig_ctx_len(self, value: int):
self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
def add_rope_scaling_finetuned(self, value: bool):
self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
def add_tokenizer_model(self, model: str): def add_tokenizer_model(self, model: str):
self.add_string(KEY_TOKENIZER_MODEL, model) self.add_string(KEY_TOKENIZER_MODEL, model)

192
llama.cpp
View File

@ -54,6 +54,7 @@
#include <cassert> #include <cassert>
#include <cinttypes> #include <cinttypes>
#include <climits> #include <climits>
#include <cmath>
#include <cstdarg> #include <cstdarg>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
@ -235,6 +236,10 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_LIST,
@ -279,6 +284,10 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@ -552,6 +561,22 @@ do { \
} \ } \
} while (0) } while (0)
static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_NONE, "none" },
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
};
static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
if (kv.second == name) {
return kv.first;
}
}
return LLAMA_ROPE_SCALING_UNSPECIFIED;
}
// //
// ggml helpers // ggml helpers
// //
@ -1037,6 +1062,9 @@ struct llama_hparams {
float rope_freq_base_train; float rope_freq_base_train;
float rope_freq_scale_train; float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
int8_t rope_scaling_type_train : 3;
bool rope_finetuned : 1;
float f_clamp_kqv; float f_clamp_kqv;
float f_max_alibi_bias; float f_max_alibi_bias;
@ -1051,6 +1079,8 @@ struct llama_hparams {
if (this->n_layer != other.n_layer) return true; if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true; if (this->n_rot != other.n_rot) return true;
if (this->n_ff != other.n_ff) return true; if (this->n_ff != other.n_ff) return true;
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
const float EPSILON = 1e-9; const float EPSILON = 1e-9;
@ -1084,6 +1114,14 @@ struct llama_cparams {
float rope_freq_base; float rope_freq_base;
float rope_freq_scale; float rope_freq_scale;
uint32_t n_yarn_orig_ctx;
// These hyperparameters are not exposed in GGUF, because all
// existing YaRN models use the same values for them.
float yarn_ext_factor;
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
bool mul_mat_q; bool mul_mat_q;
}; };
@ -2014,14 +2052,30 @@ static void llm_load_hparams(
hparams.n_head_kv = hparams.n_head; hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
hparams.rope_finetuned = false;
GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
kv(LLM_KV_ROPE_SCALING_FINETUNED));
hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
// rope_freq_base (optional) // rope_freq_base (optional)
hparams.rope_freq_base_train = 10000.0f; hparams.rope_freq_base_train = 10000.0f;
GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
std::string rope_scaling("linear");
GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
// rope_freq_scale (inverse of the kv) is optional // rope_freq_scale (inverse of the kv) is optional
float ropescale = 1.0f; float ropescale = 0.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
if (ropescale == 0.0f) { // try the old key name
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
hparams.rope_freq_scale_train = 1.0f/ropescale; }
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
// sanity check for n_rot (optional) // sanity check for n_rot (optional)
{ {
@ -2371,6 +2425,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const auto & vocab = model.vocab; const auto & vocab = model.vocab;
const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
// hparams // hparams
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
@ -2389,8 +2445,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
@ -3047,21 +3106,11 @@ static void llm_load_tensors(
model.t_load_us = ggml_time_us() - model.t_start_us; model.t_load_us = ggml_time_us() - model.t_start_us;
} }
static bool llama_model_load( static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
const std::string & fname,
llama_model & model,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
bool use_mmap,
bool use_mlock,
bool vocab_only,
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try { try {
llama_model_loader ml(fname, use_mmap); llama_model_loader ml(fname, params.use_mmap);
model.hparams.vocab_only = vocab_only; model.hparams.vocab_only = params.vocab_only;
llm_load_arch (ml, model); llm_load_arch (ml, model);
llm_load_hparams(ml, model); llm_load_hparams(ml, model);
@ -3073,15 +3122,15 @@ static bool llama_model_load(
throw std::runtime_error("vocab size mismatch"); throw std::runtime_error("vocab size mismatch");
} }
if (vocab_only) { if (params.vocab_only) {
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
return true; return true;
} }
llm_load_tensors( llm_load_tensors(
ml, model, n_gpu_layers, ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
main_gpu, tensor_split, params.progress_callback, params.progress_callback_user_data
use_mlock, progress_callback, progress_callback_user_data); );
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
return false; return false;
@ -3150,6 +3199,7 @@ static struct ggml_tensor * llm_build_inp_embd(
static void llm_build_k_shift( static void llm_build_k_shift(
struct ggml_context * ctx, struct ggml_context * ctx,
const llama_hparams & hparams, const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache & kv, const llama_kv_cache & kv,
struct ggml_cgraph * graph, struct ggml_cgraph * graph,
llm_rope_type type, llm_rope_type type,
@ -3162,6 +3212,11 @@ static void llm_build_k_shift(
const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_gqa = hparams.n_embd_gqa(); const int64_t n_embd_gqa = hparams.n_embd_gqa();
const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_head = hparams.n_embd_head();
const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
const float ext_factor = cparams.yarn_ext_factor;
const float attn_factor = cparams.yarn_attn_factor;
const float beta_fast = cparams.yarn_beta_fast;
const float beta_slow = cparams.yarn_beta_slow;
GGML_ASSERT(n_embd_head % n_rot == 0); GGML_ASSERT(n_embd_head % n_rot == 0);
@ -3185,7 +3240,8 @@ static void llm_build_k_shift(
ggml_element_size(kv.k)*n_embd_head, ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa, ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il), ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
K_shift, n_rot, rope_type, 0, freq_base, freq_scale); K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il); cb(tmp, "K_shifted", il);
ggml_build_forward_expand(graph, tmp); ggml_build_forward_expand(graph, tmp);
} }
@ -3442,12 +3498,17 @@ struct llm_build_context {
const float freq_base; const float freq_base;
const float freq_scale; const float freq_scale;
const float ext_factor;
const float attn_factor;
const float beta_fast;
const float beta_slow;
const float norm_eps; const float norm_eps;
const float norm_rms_eps; const float norm_rms_eps;
const int32_t n_tokens; const int32_t n_tokens;
const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx) const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
const int32_t kv_head; // index of where we store new KV data in the cache const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
const bool do_rope_shift; const bool do_rope_shift;
@ -3477,11 +3538,16 @@ struct llm_build_context {
n_embd_gqa (hparams.n_embd_gqa()), n_embd_gqa (hparams.n_embd_gqa()),
freq_base (cparams.rope_freq_base), freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale), freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),
attn_factor (cparams.yarn_attn_factor),
beta_fast (cparams.yarn_beta_fast),
beta_slow (cparams.yarn_beta_slow),
norm_eps (hparams.f_norm_eps), norm_eps (hparams.f_norm_eps),
norm_rms_eps (hparams.f_norm_rms_eps), norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens), n_tokens (batch.n_tokens),
n_kv (worst_case ? n_ctx : kv_self.n), n_kv (worst_case ? n_ctx : kv_self.n),
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
do_rope_shift (worst_case || kv_self.has_shift), do_rope_shift (worst_case || kv_self.has_shift),
cb (cb), cb (cb),
buf_compute (lctx.buf_compute) { buf_compute (lctx.buf_compute) {
@ -3532,7 +3598,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -3556,10 +3622,18 @@ struct llm_build_context {
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@ -3634,7 +3708,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -3658,8 +3732,16 @@ struct llm_build_context {
switch (model.type) { switch (model.type) {
case MODEL_7B: case MODEL_7B:
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); Qcur = ggml_rope_custom(
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
break; break;
case MODEL_13B: case MODEL_13B:
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
@ -3746,7 +3828,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -3786,10 +3868,16 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
// using mode = 2 for neox mode // using mode = 2 for neox mode
Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@ -3960,7 +4048,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -4053,12 +4141,14 @@ struct llm_build_context {
cb(kpass, "kpass", il); cb(kpass, "kpass", il);
struct ggml_tensor * qrotated = ggml_rope_custom( struct ggml_tensor * qrotated = ggml_rope_custom(
ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(qrotated, "qrotated", il); cb(qrotated, "qrotated", il);
struct ggml_tensor * krotated = ggml_rope_custom( struct ggml_tensor * krotated = ggml_rope_custom(
ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(krotated, "krotated", il); cb(krotated, "krotated", il);
@ -7883,8 +7973,13 @@ struct llama_context_params llama_context_default_params() {
/*.n_batch =*/ 512, /*.n_batch =*/ 512,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
/*.rope_freq_base =*/ 0.0f, /*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ NAN,
/*.yarn_attn_factor =*/ 1.0f,
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.mul_mat_q =*/ true, /*.mul_mat_q =*/ true,
/*.f16_kv =*/ true, /*.f16_kv =*/ true,
/*.logits_all =*/ false, /*.logits_all =*/ false,
@ -7971,10 +8066,7 @@ struct llama_model * llama_load_model_from_file(
}; };
} }
if (!llama_model_load(path_model, *model, params.n_gpu_layers, if (!llama_model_load(path_model, *model, params)) {
params.main_gpu, params.tensor_split,
params.use_mmap, params.use_mlock, params.vocab_only,
params.progress_callback, params.progress_callback_user_data)) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
delete model; delete model;
return nullptr; return nullptr;
@ -8001,13 +8093,35 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams; auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch; cparams.n_batch = params.n_batch;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
cparams.n_threads = params.n_threads; cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch; cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.mul_mat_q = params.mul_mat_q; cparams.mul_mat_q = params.mul_mat_q;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
hparams.n_ctx_train;
auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;
}
if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
}
if (std::isnan(cparams.yarn_ext_factor)) { // NaN indicates 'not set'
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
}
if (params.seed == LLAMA_DEFAULT_SEED) { if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL); params.seed = time(NULL);
} }

14
llama.h
View File

@ -106,6 +106,14 @@ extern "C" {
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
}; };
enum llama_rope_scaling_type {
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_NONE = 0,
LLAMA_ROPE_SCALING_LINEAR = 1,
LLAMA_ROPE_SCALING_YARN = 2,
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
};
typedef struct llama_token_data { typedef struct llama_token_data {
llama_token id; // token id llama_token id; // token id
float logit; // log-odds of the token float logit; // log-odds of the token
@ -172,10 +180,16 @@ extern "C" {
uint32_t n_batch; // prompt processing maximum batch size uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_threads; // number of threads to use for generation uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing uint32_t n_threads_batch; // number of threads to use for batch processing
int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
// ref: https://github.com/ggerganov/llama.cpp/pull/2054 // ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model float rope_freq_base; // RoPE base frequency, 0 = from model
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
float yarn_attn_factor; // YaRN magnitude scaling factor
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
// Keep the booleans together to avoid misalignment during copy-by-value. // Keep the booleans together to avoid misalignment during copy-by-value.
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)