Update test-tokenizer-random.py

Added a try/except around the tokenizer comparison, a GCC availability check, context-manager support for LibLlama, a rewritten find_first_mismatch, and new command-line arguments.
Robert 2024-11-12 22:16:34 -08:00
parent 54ef9cfc72
commit 60fd27b68d


@@ -1,30 +1,38 @@
-# Test libllama tokenizer == AutoTokenizer.
-# Brute force random words/text generation.
+#!/usr/bin/env python3
+"""
+# Test libllama tokenizer against AutoTokenizer using brute force random words/text generation.
 # Sample usage:
-# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
+# python3 test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
+"""
 from __future__ import annotations
+#
 import time
 import logging
 import argparse
+import shutil
 import subprocess
 import random
 import unicodedata
 from pathlib import Path
 from typing import Any, Iterator, cast
 from typing_extensions import Buffer
+#
+# External Imports
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizer
+#
+####################################################################################################
+#
+# Classes:
 logger = logging.getLogger("test-tokenizer-random")
+if shutil.which("gcc") is None:
+    raise EnvironmentError("GCC is not available on this system. Please install GCC or use preprocessed headers.")
 class LibLlama:
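The GCC preflight added above presumably exists because LibLlama preprocesses llama.h with gcc before handing it to cffi, so a missing compiler would otherwise surface later as a more obscure failure. A minimal sketch of that kind of preprocessing call (illustrative only; the exact flags and header path used by the class are not shown in this hunk):

    import subprocess

    # Run the C preprocessor so cffi receives plain C declarations
    # (macros expanded, includes resolved). Flags and path are assumptions,
    # not LibLlama's own command line.
    res = subprocess.run(["gcc", "-E", "-P", "./include/llama.h"],
                         stdout=subprocess.PIPE, check=True)
    header_source = res.stdout.decode()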
@@ -32,6 +40,12 @@ class LibLlama:
     DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
     DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.free()
+
     def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None):
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
         path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
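With the __enter__/__exit__ pair added above, LibLlama can act as a context manager, so free() runs even if an exception escapes the block. A minimal usage sketch (paths fall back to the class defaults; the body is illustrative):

    # Hypothetical usage of the new context-manager support.
    with LibLlama() as lib:
        # ... exercise the tokenizer through `lib` here ...
        pass
    # __exit__ has already called lib.free() at this point.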
@@ -408,14 +422,17 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
+    try:
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
+        # def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
+        #     for i, (a, b) in enumerate(zip(ids1, ids2)):
+        #         if a != b:
+        #             return i
+        #     if len(ids1) == len(ids2):
+        #         return -1
+        #     return min(len(ids1), len(ids2))
+        # Rewritten to use zip() and next() instead of for loop
+        def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
+            return next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)

     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2: # equal to TokenizerGroundtruth?
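One behavioral difference in the rewrite: when the two sequences share a common prefix but differ in length, the new helper returns -1, whereas the old version returned min(len(ids1), len(ids2)). A tiny illustration, treating the nested helper as if it were callable at module level:

    find_first_mismatch([1, 2, 3], [1, 2, 3])   # -1: no mismatch
    find_first_mismatch([1, 2, 3], [1, 9, 3])   # 1: first differing index
    find_first_mismatch([1, 2, 3], [1, 2])      # -1 now; the old version returned 2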
@@ -478,13 +495,17 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     t_total = time.perf_counter() - t_start
     logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    except Exception as e:
+        logger.exception(f"An error occurred during tokenizer comparison: {e}")

 def main(argv: list[str] | None = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+    parser.add_argument("--max-errors", type=int, default=10, help="Maximum number of errors before stopping")
+    parser.add_argument("--iterations", type=int, default=100, help="Number of iterations for random generators")
+    parser.add_argument("--tokenizers", nargs="+", help="List of tokenizers to test", default=tokenizers)
     args = parser.parse_args(argv)
     logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
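A usage sketch for the new options (flag names and defaults come from the add_argument calls above; whether and how main() forwards them to the generators is not shown in this hunk):

    # Hypothetical: exercise the new flags by calling main() with an explicit argv,
    # using the model paths from the header comment.
    main([
        "./models/ggml-vocab-llama-bpe.gguf",
        "./models/tokenizers/llama-bpe",
        "--max-errors", "5",
        "--iterations", "50",
    ])

Note that the --tokenizers default references a tokenizers list defined elsewhere in the script, so that name must already be bound when the add_argument() call is evaluated.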