gguf : add rope_freq_base parameter for CodeLlama (#2769)

This commit is contained in:
slaren 2023-08-24 20:04:05 +02:00 committed by GitHub
parent 01f2224682
commit 0d3094f0c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 21 deletions

View File

@ -104,6 +104,8 @@ class Params:
n_head_kv: int
f_norm_eps: float
f_rope_freq_base: Optional[float] = None
ftype: Optional[GGMLFileType] = None
# path to the directory containing the model files
@ -194,15 +196,16 @@ class Params:
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
config = json.load(open(config_path))
n_vocab = config["vocab_size"] if "vocab_size" in config else -1
n_embd = config["dim"]
n_layer = config["n_layers"]
n_mult = config["multiple_of"]
n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
n_ff = -1
n_head = config["n_heads"]
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
f_norm_eps = config["norm_eps"]
n_vocab = config["vocab_size"] if "vocab_size" in config else -1
n_embd = config["dim"]
n_layer = config["n_layers"]
n_mult = config["multiple_of"]
n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
n_ff = -1
n_head = config["n_heads"]
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
f_norm_eps = config["norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
if n_vocab == -1:
n_vocab = model["tok_embeddings.weight"].shape[0]
@ -211,15 +214,16 @@ class Params:
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
return Params(
n_vocab = n_vocab,
n_embd = n_embd,
n_mult = n_mult,
n_layer = n_layer,
n_ctx = n_ctx,
n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
n_vocab = n_vocab,
n_embd = n_embd,
n_mult = n_mult,
n_layer = n_layer,
n_ctx = n_ctx,
n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
f_rope_freq_base = f_rope_freq_base,
)
@staticmethod
@ -754,6 +758,9 @@ class OutputFile:
self.gguf.add_head_count_kv (params.n_head_kv)
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
if params.f_rope_freq_base:
self.gguf.add_rope_freq_base(params.f_rope_freq_base)
if params.ftype:
self.gguf.add_file_type(params.ftype)

View File

@ -47,6 +47,7 @@ KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
# RoPE
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear"
# tokenization
@ -663,7 +664,10 @@ class GGUFWriter:
self.add_uint32(
KEY_ROPE_DIMENSION_COUNT.format(arch=self.arch), count)
def add_rope_scale_linear(self, value: float):
def add_rope_freq_base(self, value: float):
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_rope_scale_linear(self, value: float):
self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
def add_tokenizer_model(self, model: str):

View File

@ -195,6 +195,7 @@ enum llm_kv {
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_TOKENIZER_MODEL,
@ -238,6 +239,7 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
@ -1561,12 +1563,26 @@ static void llm_load_hparams(
hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
// TODO: manually setting rope scale should override this
// TODO: manually setting rope freq base and scale should override this
// FIXME: partial fix when the param specified is not the default value, but
// will not work for overriding the model value to the params default
llama_context_params defaults = llama_context_default_params();
// rope_freq_base
{
float ropebase = 10000.0f;
GGUF_GET_KEY(ctx, ropebase, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
if (ropebase != 10000.0f && rope_freq_base == defaults.rope_freq_base) {
rope_freq_base = ropebase;
}
}
// rope_freq_scale (inverse of the kv) is optional
{
float ropescale = 1.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (ropescale != 1.0f) {
if (ropescale != 1.0f && rope_freq_scale == defaults.rope_freq_scale) {
rope_freq_scale = 1.0f/ropescale;
}
}