mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-13 14:29:52 +00:00
convert : refactor rope_freqs generation (#9396)
* convert : refactor rope_freqs generation This should also fix vocab-only conversion for Phi-3. * convert : adapt MiniCPM3 to separate rope_freqs insertion MiniCPM3's tokenizer is treated as a SentencePiece tokenizer to avoid having to run its custom Python code which mixes tokenization in the same file as tool calls. gguf-py : add long and short RoPE factors to tensor mappings Empty, but the key names are used to populate the mappings.
This commit is contained in:
parent
6f1d9d71f4
commit
1927378bcc
@ -15,6 +15,7 @@ from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
|
||||
from itertools import chain
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
@ -64,7 +65,6 @@ class Model:
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
dir_model_card: Path
|
||||
is_lora: bool
|
||||
|
||||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
@ -72,7 +72,7 @@ class Model:
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False):
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
|
||||
if type(self) is Model:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
@ -94,7 +94,6 @@ class Model:
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
@ -270,10 +269,14 @@ class Model:
|
||||
|
||||
return False
|
||||
|
||||
# some models need extra generated tensors (like rope_freqs)
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
return ()
|
||||
|
||||
def prepare_tensors(self):
|
||||
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
||||
|
||||
for name, data_torch in self.get_tensors():
|
||||
for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()):
|
||||
# we don't need these
|
||||
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
@ -1617,7 +1620,7 @@ class LlamaModel(Model):
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||
base = self.hparams.get("rope_theta", 10000.0)
|
||||
@ -1644,9 +1647,9 @@ class LlamaModel(Model):
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
||||
|
||||
if not self.is_lora:
|
||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
@ -1870,8 +1873,6 @@ class MiniCPM3Model(Model):
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
|
||||
rope_dims = hparams["qk_rope_head_dim"]
|
||||
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
@ -1887,24 +1888,25 @@ class MiniCPM3Model(Model):
|
||||
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is None:
|
||||
return
|
||||
if rope_scaling is not None:
|
||||
rope_dims = self.hparams["qk_rope_head_dim"]
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_llama_hf()
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
|
||||
if n_kv_head is not None and n_head != n_kv_head:
|
||||
@ -2216,6 +2218,13 @@ class Phi3MiniModel(Model):
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"]))
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rope_dims = n_embd // n_head
|
||||
|
||||
# write rope scaling for long context (128k) model
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is None:
|
||||
@ -2245,9 +2254,8 @@ class Phi3MiniModel(Model):
|
||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
|
||||
if not self.is_lora:
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
|
||||
|
||||
|
||||
@Model.register("PlamoForCausalLM")
|
||||
@ -4071,7 +4079,7 @@ class ExaoneModel(Model):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
|
||||
|
||||
def prepare_tensors(self):
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||
base = self.hparams.get("rope_theta", 10000.0)
|
||||
@ -4098,10 +4106,7 @@ class ExaoneModel(Model):
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
||||
|
||||
if not self.is_lora:
|
||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||
|
||||
super().prepare_tensors()
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
|
||||
|
||||
|
||||
@Model.register("GraniteForCausalLM")
|
||||
|
@ -331,6 +331,10 @@ if __name__ == '__main__':
|
||||
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
|
||||
super().set_gguf_parameters()
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters
|
||||
return ()
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
tensor_map: dict[str, PartialLoraTensor] = {}
|
||||
|
||||
@ -392,7 +396,6 @@ if __name__ == '__main__':
|
||||
dry_run=args.dry_run,
|
||||
dir_lora_model=dir_lora,
|
||||
lora_alpha=alpha,
|
||||
is_lora=True,
|
||||
)
|
||||
|
||||
logger.info("Exporting model...")
|
||||
|
@ -814,6 +814,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG,
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
@ -892,6 +894,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG,
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
|
@ -87,6 +87,9 @@ class TensorNameMap:
|
||||
"rope.freqs", # llama-pth
|
||||
"rotary_pos_emb.inv_freq", # chatglm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG: (),
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
|
||||
}
|
||||
|
||||
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||
|
Loading…
Reference in New Issue
Block a user