mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
convert-hf : fix exception in sentencepiece with added tokens (#6320)
This commit is contained in:
parent
d25b1c31b0
commit
e097633f63
@ -331,7 +331,7 @@ class Model(ABC):
|
|||||||
tokenizer = SentencePieceProcessor(str(tokenizer_path))
|
tokenizer = SentencePieceProcessor(str(tokenizer_path))
|
||||||
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
|
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
|
||||||
|
|
||||||
for token_id in range(vocab_size):
|
for token_id in range(tokenizer.vocab_size()):
|
||||||
piece = tokenizer.id_to_piece(token_id)
|
piece = tokenizer.id_to_piece(token_id)
|
||||||
text = piece.encode("utf-8")
|
text = piece.encode("utf-8")
|
||||||
score = tokenizer.get_score(token_id)
|
score = tokenizer.get_score(token_id)
|
||||||
@ -356,9 +356,13 @@ class Model(ABC):
|
|||||||
added_tokens_json = json.load(f)
|
added_tokens_json = json.load(f)
|
||||||
|
|
||||||
for key in added_tokens_json:
|
for key in added_tokens_json:
|
||||||
tokens.append(key.encode("utf-8"))
|
key = key.encode("utf-8")
|
||||||
scores.append(-1000.0)
|
if key not in tokens:
|
||||||
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
|
tokens.append(key)
|
||||||
|
scores.append(-1000.0)
|
||||||
|
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
|
||||||
|
|
||||||
|
assert len(tokens) == vocab_size
|
||||||
|
|
||||||
self.gguf_writer.add_tokenizer_model("llama")
|
self.gguf_writer.add_tokenizer_model("llama")
|
||||||
self.gguf_writer.add_token_list(tokens)
|
self.gguf_writer.add_token_list(tokens)
|
||||||
|
Loading…
Reference in New Issue
Block a user