mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-09-22 21:16:20 +00:00)
convert : adapt MiniCPM3 to separate rope_freqs insertion

MiniCPM3's tokenizer is treated as a SentencePiece tokenizer to avoid
having to run its custom Python code which mixes tokenization
in the same file as tool calls.

gguf-py : add long and short RoPE factors to tensor mappings

Empty, but the key names are used to populate the mappings.
commit e83d2707d3
parent ed0f2c4ab1
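The convert_hf_to_gguf.py change below replaces the direct gguf_writer.add_tensor() calls in set_gguf_parameters() with a generate_extra_tensors() override that yields (name, tensor) pairs, so the base converter can insert the rope_freqs tensors at a consistent point for every model. A rough sketch of that yield-based hook, using a simplified stand-in for the converter's Model base class (the class, method bodies, and tensor names here are illustrative, not the real plumbing):

# Rough sketch of the "yield extra tensors" hook, not the real converter code.
from typing import Iterable

import torch
from torch import Tensor


class SketchModel:
    """Simplified stand-in for the converter's Model base class."""

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        return ()  # default: subclasses yield (name, tensor) pairs here

    def all_tensors(self) -> Iterable[tuple[str, Tensor]]:
        # The base class controls where the extra tensors are inserted,
        # instead of each subclass writing them out directly.
        yield from self.generate_extra_tensors()
        yield ("token_embd.weight", torch.zeros(4, 4))  # placeholder model tensor


class SketchMiniCPM3(SketchModel):
    rope_scaling = {"long_factor": [1.0, 2.0], "short_factor": [1.0, 1.0]}

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        long_factors = self.rope_scaling.get("long_factor")
        short_factors = self.rope_scaling.get("short_factor")
        if long_factors is not None and short_factors is not None:
            yield ("rope_factors_long.weight", torch.tensor(long_factors, dtype=torch.float32))
            yield ("rope_factors_short.weight", torch.tensor(short_factors, dtype=torch.float32))


for name, tensor in SketchMiniCPM3().all_tensors():
    print(name, tuple(tensor.shape))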
convert_hf_to_gguf.py
@@ -1862,8 +1862,6 @@ class MiniCPM3Model(Model):
     def set_gguf_parameters(self):
         hparams = self.hparams
 
-        rope_dims = hparams["qk_rope_head_dim"]
-
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
@@ -1879,24 +1877,25 @@ class MiniCPM3Model(Model):
         self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         rope_scaling = self.find_hparam(['rope_scaling'], True)
-        if rope_scaling is None:
-            return
+        if rope_scaling is not None:
+            rope_dims = self.hparams["qk_rope_head_dim"]
 
-        long_factors = rope_scaling.get('long_factor', None)
-        short_factors = rope_scaling.get('short_factor', None)
+            long_factors = rope_scaling.get('long_factor', None)
+            short_factors = rope_scaling.get('short_factor', None)
 
-        if long_factors is None or short_factors is None:
-            raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
+            if long_factors is None or short_factors is None:
+                raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
 
-        if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
-            raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
+            if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
+                raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
 
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
     def set_vocab(self):
-        self._set_vocab_llama_hf()
+        self._set_vocab_sentencepiece()
 
     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
         if n_kv_head is not None and n_head != n_kv_head:
gguf-py/gguf/constants.py
@@ -877,6 +877,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FACTORS_LONG,
+        MODEL_TENSOR.ROPE_FACTORS_SHORT,
         MODEL_TENSOR.ATTN_NORM,
         MODEL_TENSOR.ATTN_Q_A,
         MODEL_TENSOR.ATTN_Q_B,
gguf-py/gguf/tensor_mapping.py
@@ -87,6 +87,9 @@ class TensorNameMap:
             "rope.freqs",  # llama-pth
             "rotary_pos_emb.inv_freq",  # chatglm
         ),
+
+        MODEL_TENSOR.ROPE_FACTORS_LONG: (),
+        MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
     }
 
     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
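The gguf-py/gguf/tensor_mapping.py hunk is what the commit message means by "Empty, but the key names are used to populate the mappings": when TensorNameMap builds its lookup table, the GGUF-side tensor name is registered alongside any Hugging Face source names, so an entry with an empty tuple still makes names like rope_factors_long resolvable for tensors the converter generates itself. A minimal sketch of that behaviour (simplified; the dictionaries below are illustrative stand-ins for gguf-py's TENSOR_NAMES and mappings_cfg):

# Simplified illustration of why an empty mapping tuple still matters:
# the GGUF-side name itself is added to the lookup table.
TENSOR_NAMES = {  # illustrative subset of the GGUF tensor names
    "ROPE_FACTORS_LONG": "rope_factors_long",
    "ROPE_FACTORS_SHORT": "rope_factors_short",
}

mappings_cfg = {  # per-tensor Hugging Face names; empty for generated tensors
    "ROPE_FACTORS_LONG": (),
    "ROPE_FACTORS_SHORT": (),
}

mapping: dict[str, str] = {}
for tensor, hf_keys in mappings_cfg.items():
    gguf_name = TENSOR_NAMES[tensor]
    mapping[gguf_name] = gguf_name  # the key name populates the mapping
    for key in hf_keys:             # nothing extra to add for empty tuples
        mapping[key] = gguf_name

print(mapping["rope_factors_long"])  # -> rope_factors_long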