py : handle byte tokens in get_token_type (#5341)

* py : handle byte tokens in `get_token_type`

* py : fix empty bytes arg
This commit is contained in:
Georgi Gerganov 2024-02-06 07:47:22 +02:00 committed by GitHub
parent 098f6d737b
commit 906cff55c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -515,10 +515,14 @@ class HfVocab:
# Yield token text, score, and type # Yield token text, score, and type
yield token_text, self.get_token_score(token_id), self.get_token_type( yield token_text, self.get_token_score(token_id), self.get_token_type(
token_id, self.special_ids # Reuse already stored special IDs token_id, token_text, self.special_ids # Reuse already stored special IDs
) )
def get_token_type(self, token_id: int, special_ids: set[int]) -> gguf.TokenType: def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
# Special case for byte tokens
if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
return gguf.TokenType.BYTE
# Determine token type based on whether it's a special token # Determine token type based on whether it's a special token
return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
@ -530,7 +534,7 @@ class HfVocab:
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
for text in self.added_tokens_list: for text in self.added_tokens_list:
if text in self.specials: if text in self.specials:
toktype = self.get_token_type(self.specials[text], self.special_ids) toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
score = self.get_token_score(self.specials[text]) score = self.get_token_score(self.specials[text])
else: else:
toktype = gguf.TokenType.USER_DEFINED toktype = gguf.TokenType.USER_DEFINED