diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index 49e9b7528..00be596ce 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -89,6 +89,22 @@ let ps.tiktoken ps.torchWithoutCuda ps.transformers + + # server bench + ps.matplotlib + + # server tests + ps.openai + ps.behave + ps.prometheus-client + + # for examples/pydantic-models-to-grammar-examples.py + ps.docstring-parser + ps.pydantic + + # for scripts/compare-llama-bench.py + ps.gitpython + ps.tabulate ] ); diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml new file mode 100644 index 000000000..e5ff5e6d7 --- /dev/null +++ b/.github/workflows/python-type-check.yml @@ -0,0 +1,38 @@ +name: Python Type-Check + +on: + push: + paths: + - '.github/workflows/python-type-check.yml' + - '**.py' + - '**/requirements*.txt' + pull_request: + paths: + - '.github/workflows/python-type-check.yml' + - '**.py' + - '**/requirements*.txt' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + python-type-check: + runs-on: ubuntu-latest + name: pyright type-check + steps: + - name: Check out source repository + uses: actions/checkout@v4 + - name: Set up Python environment + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install Python dependencies + # TODO: use a venv + run: pip install -r requirements/requirements-all.txt + - name: Type-check with Pyright + uses: jakebailey/pyright-action@v2 + with: + version: 1.1.370 + level: warning + warnings: true diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6ee41d3a1..6cea73f08 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -265,7 +265,7 @@ class Model: break for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray = data # type hint + data: np.ndarray # type hint n_dims = len(data.shape) data_dtype = data.dtype data_qtype: gguf.GGMLQuantizationType | None = None @@ -599,10 +599,6 @@ class Model: tokenizer_path = self.dir_model / 'tokenizer.model' - tokens: list[bytes] = [] - scores: list[float] = [] - toktypes: list[int] = [] - if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") @@ -2120,7 +2116,7 @@ class InternLM2Model(Model): logger.error(f'Error: Missing {tokenizer_path}') sys.exit(1) - sentencepiece_model = model.ModelProto() + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix @@ -2972,16 +2968,16 @@ class T5Model(Model): if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") - sentencepiece_model = model.ModelProto() + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) # some models like Pile-T5 family use BPE tokenizer instead of Unigram - if sentencepiece_model.trainer_spec.model_type == 2: # BPE + if sentencepiece_model.trainer_spec.model_type == 2: # BPE # assure the tokenizer model file name is correct assert tokenizer_path.name == 'tokenizer.model' return self._set_vocab_sentencepiece() else: - assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces @@ -3152,7 +3148,7 @@ class JaisModel(Model): # but Jais's PyTorch model simply precalculates the slope values and places them # in relative_pes.slopes n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"])) - first_val = float(data_torch._data[0]) + first_val = float(data_torch[0].item()) self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2) return tensors @@ -3186,7 +3182,7 @@ class ChatGLMModel(Model): def set_vocab_chatglm3(self): dir_model = self.dir_model hparams = self.hparams - tokens: list[bytearray] = [] + tokens: list[bytes] = [] toktypes: list[int] = [] scores: list[float] = [] @@ -3335,7 +3331,7 @@ class ChatGLMModel(Model): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b + self.gguf_writer.add_name(self.hparams["_name_or_path"].split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_head_kv = self.hparams.get("multi_query_group_num", n_head) diff --git a/convert_llama_ggml_to_gguf.py b/convert_llama_ggml_to_gguf.py index 9349de3b3..95ea831a5 100755 --- a/convert_llama_ggml_to_gguf.py +++ b/convert_llama_ggml_to_gguf.py @@ -354,7 +354,8 @@ class GGMLToGGUF: def handle_metadata(cfg, hp): - import convert + import examples.convert_legacy_llama as convert + assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory' hf_config_path = cfg.model_metadata_dir / "config.json" orig_config_path = cfg.model_metadata_dir / "params.json" diff --git a/examples/convert_legacy_llama.py b/examples/convert_legacy_llama.py index 721a57c00..c2c73e8ad 100755 --- a/examples/convert_legacy_llama.py +++ b/examples/convert_legacy_llama.py @@ -353,7 +353,7 @@ class Metadata: version: Optional[str] = None url: Optional[str] = None description: Optional[str] = None - licence: Optional[str] = None + license: Optional[str] = None source_url: Optional[str] = None source_hf_repo: Optional[str] = None @@ -492,12 +492,13 @@ class LazyTensor: LazyModel: TypeAlias = 'dict[str, LazyTensor]' +ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none'] @dataclass class ModelPlus: model: LazyModel paths: list[Path] # Where this was read from. - format: Literal['ggml', 'torch', 'safetensors', 'none'] + format: ModelFormat vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab. @@ -536,7 +537,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel: def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: - formats = set(mp.format for mp in models_plus) + formats: set[ModelFormat] = set(mp.format for mp in models_plus) assert len(formats) == 1, "different formats?" format = formats.pop() paths = [path for mp in models_plus for path in mp.paths] @@ -555,7 +556,7 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: else: model = merge_sharded([mp.model for mp in models_plus]) - return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types + return ModelPlus(model, paths, format, vocab) def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: @@ -805,7 +806,7 @@ class OutputFile: def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE): self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) - def add_meta_model(self, params: Params, metadata: Metadata) -> None: + def add_meta_model(self, params: Params, metadata: Metadata | None) -> None: # Metadata About The Model And Its Provenence name = "LLaMA" if metadata is not None and metadata.name is not None: @@ -827,8 +828,8 @@ class OutputFile: self.gguf.add_url(metadata.url) if metadata.description is not None: self.gguf.add_description(metadata.description) - if metadata.licence is not None: - self.gguf.add_licence(metadata.licence) + if metadata.license is not None: + self.gguf.add_licence(metadata.license) if metadata.source_url is not None: self.gguf.add_source_url(metadata.source_url) if metadata.source_hf_repo is not None: @@ -943,7 +944,7 @@ class OutputFile: @staticmethod def write_vocab_only( fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, - endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata = None, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None, ) -> None: check_vocab_size(params, vocab, pad_vocab=pad_vocab) @@ -977,7 +978,7 @@ class OutputFile: fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, - metadata: Metadata = None, + metadata: Metadata | None = None, ) -> None: check_vocab_size(params, vocab, pad_vocab=pad_vocab) @@ -1396,6 +1397,8 @@ def main(args_in: list[str] | None = None) -> None: if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab: vocab = model_plus.vocab + assert params is not None + logger.info(f"Vocab info: {vocab}") logger.info(f"Special vocab info: {special_vocab}") model = model_plus.model diff --git a/examples/finetune/convert_finetune_checkpoint_to_gguf.py b/examples/finetune/convert_finetune_checkpoint_to_gguf.py index c89090918..1b79d6995 100644 --- a/examples/finetune/convert_finetune_checkpoint_to_gguf.py +++ b/examples/finetune/convert_finetune_checkpoint_to_gguf.py @@ -74,7 +74,7 @@ class Tensor: if len(self.ne) == 0: self.nbytes = 0 else: - self.nbytes = int(np.product(self.ne)) * 4 + self.nbytes = int(np.prod(self.ne)) * 4 else: raise ValueError(f"Unhandled data type '{self.dtype}'") diff --git a/examples/json_schema_pydantic_example.py b/examples/json_schema_pydantic_example.py index c7ca7b8d9..19c0bdb5b 100644 --- a/examples/json_schema_pydantic_example.py +++ b/examples/json_schema_pydantic_example.py @@ -3,7 +3,7 @@ #! pip install pydantic #! python json_schema_pydantic_example.py -from pydantic import BaseModel, Extra, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter from annotated_types import MinLen from typing import Annotated, List, Optional import json, requests @@ -17,6 +17,9 @@ if True: The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below) ''' + response_format = None + type_adapter = None + if response_model: type_adapter = TypeAdapter(response_model) schema = type_adapter.json_schema() diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 072a230f7..a8779bf3b 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import itertools import json @@ -188,7 +190,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou raise RuntimeError("At least one of min_value or max_value must be set") class BuiltinRule: - def __init__(self, content: str, deps: list = None): + def __init__(self, content: str, deps: list | None = None): self.content = content self.deps = deps or [] @@ -248,7 +250,7 @@ class SchemaConverter: def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal ) return f'"{escaped}"' @@ -403,11 +405,11 @@ class SchemaConverter: i = 0 length = len(pattern) - def to_rule(s: Tuple[str, bool]) -> str: + def to_rule(s: tuple[str, bool]) -> str: (txt, is_literal) = s return "\"" + txt + "\"" if is_literal else txt - def transform() -> Tuple[str, bool]: + def transform() -> tuple[str, bool]: ''' Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. ''' @@ -420,7 +422,7 @@ class SchemaConverter: # We only need a flat structure here to apply repetition operators to the last item, and # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially # (GBNF's syntax is luckily very close to regular expressions!) - seq: list[Tuple[str, bool]] = [] + seq: list[tuple[str, bool]] = [] def get_dot(): if self._dotall: diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index b00bf7c6d..36f6b92fb 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -185,6 +185,8 @@ else: fout.add_description("two-tower CLIP model") if has_text_encoder: + assert t_hparams is not None + assert tokens is not None # text_model hparams fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) @@ -259,8 +261,8 @@ if has_vision_encoder: if processor is not None: - image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean - image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std + image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue] + image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std # pyright: ignore[reportAttributeAccessIssue] else: image_mean = args.image_mean if args.image_mean is not None else default_image_mean image_std = args.image_std if args.image_std is not None else default_image_std @@ -272,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu) if has_llava_projector: - model.vision_model.encoder.layers.pop(-1) + model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue] projector = torch.load(args.llava_projector) for name, data in projector.items(): name = get_tensor_name(name) @@ -286,7 +288,7 @@ if has_llava_projector: print("Projector tensors added\n") -state_dict = model.state_dict() +state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue] for name, data in state_dict.items(): if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector): # we don't need this diff --git a/examples/llava/llava_surgery_v2.py b/examples/llava/llava_surgery_v2.py index eb56d6988..2d5b32fe6 100644 --- a/examples/llava/llava_surgery_v2.py +++ b/examples/llava/llava_surgery_v2.py @@ -2,7 +2,9 @@ import argparse import glob import os import torch -from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file +from safetensors import safe_open +from safetensors.torch import save_file +from typing import Any, ContextManager, cast # Function to determine if file is a SafeTensor file def is_safetensor_file(file_path): @@ -13,7 +15,7 @@ def is_safetensor_file(file_path): def load_model(file_path): if is_safetensor_file(file_path): tensors = {} - with safe_open(file_path, framework="pt", device="cpu") as f: + with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f: for key in f.keys(): tensors[key] = f.get_tensor(key).clone() # output shape @@ -134,7 +136,7 @@ if len(mm_tensors) == 0: if last_checkpoint is not None: for k, v in last_checkpoint.items(): print(k) - print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.") + print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.") print("No tensors found. Is this a LLaVA model?") exit() @@ -143,8 +145,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.") # projector = {name: checkpoint.[name].float() for name in mm_tensors} projector = {} for name in mm_tensors: + assert last_checkpoint is not None projector[name] = last_checkpoint[name].float() for name in first_mm_tensors: + assert first_checkpoint is not None projector[name] = first_checkpoint[name].float() if len(projector) > 0: diff --git a/examples/pydantic_models_to_grammar.py b/examples/pydantic_models_to_grammar.py index f029c73a2..d8145710c 100644 --- a/examples/pydantic_models_to_grammar.py +++ b/examples/pydantic_models_to_grammar.py @@ -6,10 +6,10 @@ import re from copy import copy from enum import Enum from inspect import getdoc, isclass -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin from docstring_parser import parse -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, create_model if TYPE_CHECKING: from types import GenericAlias @@ -17,6 +17,9 @@ else: # python 3.8 compat from typing import _GenericAlias as GenericAlias +# TODO: fix this +# pyright: reportAttributeAccessIssue=information + class PydanticDataType(Enum): """ @@ -234,8 +237,9 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None # Define the integer part rule integer_part_rule = ( - "integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + ( - f"-min{min_digit}" if min_digit is not None else "") + "integer-part" + + (f"-max{max_digit}" if max_digit is not None else "") + + (f"-min{min_digit}" if min_digit is not None else "") ) # Define the fractional part rule based on precision constraints @@ -458,7 +462,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas if not issubclass(model, BaseModel): # For non-Pydantic classes, generate model_fields from __annotations__ or __init__ if hasattr(model, "__annotations__") and model.__annotations__: - model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} + model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues] else: init_signature = inspect.signature(model.__init__) parameters = init_signature.parameters @@ -680,7 +684,7 @@ def generate_markdown_documentation( str: Generated text documentation. """ documentation = "" - pyd_models = [(model, True) for model in pydantic_models] + pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models] for model, add_prefix in pyd_models: if add_prefix: documentation += f"{model_prefix}: {model.__name__}\n" @@ -700,7 +704,7 @@ def generate_markdown_documentation( # Indenting the fields section documentation += f" {fields_prefix}:\n" else: - documentation += f" Fields:\n" + documentation += f" Fields:\n" # noqa: F541 if isclass(model) and issubclass(model, BaseModel): for name, field_type in model.__annotations__.items(): # if name == "markdown_code_block": @@ -778,7 +782,7 @@ def generate_field_markdown( return field_text if field_description != "": - field_text += f" Description: " + field_description + "\n" + field_text += f" Description: {field_description}\n" # Check for and include field-specific examples if available if hasattr(model, "Config") and hasattr(model.Config, @@ -833,7 +837,7 @@ def generate_text_documentation( str: Generated text documentation. """ documentation = "" - pyd_models = [(model, True) for model in pydantic_models] + pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models] for model, add_prefix in pyd_models: if add_prefix: documentation += f"{model_prefix}: {model.__name__}\n" @@ -1164,7 +1168,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]): dynamic_fields[param.name] = ( param.annotation if param.annotation != inspect.Parameter.empty else str, default_value) # Creating the dynamic model - dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) # type: ignore[call-overload] + dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) for name, param_doc in param_docs: dynamic_model.model_fields[name].description = param_doc.description @@ -1228,9 +1232,6 @@ def map_grammar_names_to_pydantic_model_class(pydantic_model_list): return output -from enum import Enum - - def json_schema_to_python_types(schema): type_map = { "any": Any, @@ -1275,7 +1276,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: if items != {}: array = {"properties": items} array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items") - fields[field_name] = (List[array_type], ...) # type: ignore[valid-type] + fields[field_name] = (List[array_type], ...) else: fields[field_name] = (list, ...) elif field_type == "object": @@ -1285,7 +1286,8 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: required = field_data.get("enum", []) for key, field in fields.items(): if key not in required: - fields[key] = (Optional[fields[key][0]], ...) + optional_type = fields[key][0] + fields[key] = (Optional[optional_type], ...) else: field_type = json_schema_to_python_types(field_type) fields[field_name] = (field_type, ...) @@ -1305,6 +1307,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: required = dictionary.get("required", []) for key, field in fields.items(): if key not in required: - fields[key] = (Optional[fields[key][0]], ...) + optional_type = fields[key][0] + fields[key] = (Optional[optional_type], ...) custom_model = create_model(model_name, **fields) return custom_model diff --git a/examples/pydantic_models_to_grammar_examples.py b/examples/pydantic_models_to_grammar_examples.py index 160966649..8e7f46cf9 100644 --- a/examples/pydantic_models_to_grammar_examples.py +++ b/examples/pydantic_models_to_grammar_examples.py @@ -1,6 +1,7 @@ # Function calling example using pydantic models. +from __future__ import annotations + import datetime -import importlib import json from enum import Enum from typing import Optional, Union @@ -215,9 +216,9 @@ for call in json_data: if call["function"] == "Calculator": print(Calculator(**call["params"]).run()) elif call["function"] == "get_current_datetime": - print(current_datetime_model(**call["params"]).run()) + print(current_datetime_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue] elif call["function"] == "get_current_weather": - print(current_weather_tool_model(**call["params"]).run()) + print(current_weather_tool_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue] # Should output something like this: # 2024-01-14 13:36:06 # {"location": "London", "temperature": "42", "unit": "celsius"} diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 4fbbb2032..2daac0884 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import json import os @@ -59,10 +61,11 @@ def main(args_in: list[str] | None = None) -> None: sys.exit(1) # start the benchmark + iterations = 0 + data = {} try: start_benchmark(args) - iterations = 0 with open("results.github.env", 'w') as github_env: # parse output with open('k6-results.json', 'r') as bench_results: @@ -129,7 +132,7 @@ def main(args_in: list[str] | None = None) -> None: timestamps, metric_values = zip(*values) metric_values = [float(value) for value in metric_values] prometheus_metrics[metric] = metric_values - timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps] + timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps] plt.figure(figsize=(16, 10), dpi=80) plt.plot(timestamps_dt, metric_values, label=metric) plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7) @@ -156,7 +159,7 @@ def main(args_in: list[str] | None = None) -> None: plt.close() # Mermaid format in case images upload failed - with (open(f"{metric}.mermaid", 'w') as mermaid_f): + with open(f"{metric}.mermaid", 'w') as mermaid_f: mermaid = ( f"""--- config: @@ -278,7 +281,7 @@ def start_server_background(args): } server_process = subprocess.Popen( args, - **pkwargs) + **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] def server_log(in_stream, out_stream): for line in iter(in_stream.readline, b''): diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 7b5dabb01..df0814cc9 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -1,5 +1,4 @@ import asyncio -import collections import json import os import re @@ -8,19 +7,23 @@ import subprocess import sys import threading import time +from collections.abc import Sequence from contextlib import closing from re import RegexFlag +from typing import Any, Literal, cast import aiohttp import numpy as np import openai -from behave import step +from openai.types.chat import ChatCompletionChunk +from behave import step # pyright: ignore[reportAttributeAccessIssue] from behave.api.async_step import async_run_until_complete from prometheus_client import parser +# pyright: reportRedeclaration=false @step("a server listening on {server_fqdn}:{server_port}") -def step_server_config(context, server_fqdn, server_port): +def step_server_config(context, server_fqdn: str, server_port: str): context.server_fqdn = server_fqdn context.server_port = int(server_port) context.n_threads = None @@ -74,34 +77,34 @@ def step_server_config(context, server_fqdn, server_port): @step('a model file {hf_file} from HF repo {hf_repo}') -def step_download_hf_model(context, hf_file, hf_repo): +def step_download_hf_model(context, hf_file: str, hf_repo: str): context.model_hf_repo = hf_repo context.model_hf_file = hf_file context.model_file = os.path.basename(hf_file) @step('a model file {model_file}') -def step_model_file(context, model_file): +def step_model_file(context, model_file: str): context.model_file = model_file @step('a model url {model_url}') -def step_model_url(context, model_url): +def step_model_url(context, model_url: str): context.model_url = model_url @step('a model alias {model_alias}') -def step_model_alias(context, model_alias): +def step_model_alias(context, model_alias: str): context.model_alias = model_alias @step('{seed:d} as server seed') -def step_seed(context, seed): +def step_seed(context, seed: int): context.server_seed = seed @step('{ngl:d} GPU offloaded layers') -def step_n_gpu_layer(context, ngl): +def step_n_gpu_layer(context, ngl: int): if 'N_GPU_LAYERS' in os.environ: new_ngl = int(os.environ['N_GPU_LAYERS']) if context.debug: @@ -111,37 +114,37 @@ def step_n_gpu_layer(context, ngl): @step('{n_threads:d} threads') -def step_n_threads(context, n_threads): +def step_n_threads(context, n_threads: int): context.n_thread = n_threads @step('{draft:d} as draft') -def step_draft(context, draft): +def step_draft(context, draft: int): context.draft = draft @step('{n_ctx:d} KV cache size') -def step_n_ctx(context, n_ctx): +def step_n_ctx(context, n_ctx: int): context.n_ctx = n_ctx @step('{n_slots:d} slots') -def step_n_slots(context, n_slots): +def step_n_slots(context, n_slots: int): context.n_slots = n_slots @step('{n_predict:d} server max tokens to predict') -def step_server_n_predict(context, n_predict): +def step_server_n_predict(context, n_predict: int): context.n_server_predict = n_predict @step('{slot_save_path} as slot save path') -def step_slot_save_path(context, slot_save_path): +def step_slot_save_path(context, slot_save_path: str): context.slot_save_path = slot_save_path @step('using slot id {id_slot:d}') -def step_id_slot(context, id_slot): +def step_id_slot(context, id_slot: int): context.id_slot = id_slot @@ -191,7 +194,7 @@ def step_start_server(context): @step("the server is {expecting_status}") @async_run_until_complete -async def step_wait_for_the_server_to_be_started(context, expecting_status): +async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str): match expecting_status: case 'healthy': await wait_for_health_status(context, context.base_url, 200, 'ok', @@ -221,7 +224,7 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status): @step('all slots are {expected_slot_status_string}') @async_run_until_complete -async def step_all_slots_status(context, expected_slot_status_string): +async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str): match expected_slot_status_string: case 'idle': expected_slot_status = 0 @@ -237,7 +240,7 @@ async def step_all_slots_status(context, expected_slot_status_string): @step('a completion request with {api_error} api error') @async_run_until_complete -async def step_request_completion(context, api_error): +async def step_request_completion(context, api_error: Literal['raised'] | str): expect_api_error = api_error == 'raised' seeds = await completions_seed(context, num_seeds=1) completion = await request_completion(context.prompts.pop(), @@ -777,8 +780,8 @@ def step_assert_metric_value(context, metric_name, metric_value): def step_available_models(context): # openai client always expects an api_key openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope' - openai.api_base = f'{context.base_url}/v1' - context.models = openai.Model.list().data + openai.base_url = f'{context.base_url}/v1/' + context.models = openai.models.list().data @step('{n_model:d} models are supported') @@ -789,7 +792,7 @@ def step_supported_models(context, n_model): @step('model {i_model:d} is {param} {preposition} {param_value}') -def step_supported_models(context, i_model, param, preposition, param_value): +def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str): assert i_model < len(context.models) model = context.models[i_model] @@ -798,7 +801,7 @@ def step_supported_models(context, i_model, param, preposition, param_value): case 'identified': value = model.id case 'trained': - value = str(model.meta.n_ctx_train) + value = str(model.meta["n_ctx_train"]) case _: assert False, "param {param} not supported" assert param_value == value, f"model param {param} {value} != {param_value}" @@ -810,6 +813,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): print(f"starting {context.n_prompts} concurrent completion requests...") assert context.n_prompts > 0 seeds = await completions_seed(context) + assert seeds is not None for prompt_no in range(context.n_prompts): shifted_args = [context.prompts.pop(), seeds[prompt_no], *args] context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs))) @@ -861,7 +865,7 @@ async def request_completion(prompt, id_slot=None, expect_api_error=None, user_api_key=None, - temperature=None): + temperature=None) -> int | dict[str, Any]: if debug: print(f"Sending completion request: {prompt}") origin = "my.super.domain" @@ -899,8 +903,8 @@ async def request_completion(prompt, async def oai_chat_completions(user_prompt, seed, system_prompt, - base_url, - base_path, + base_url: str, + base_path: str, async_client, debug=False, temperature=None, @@ -909,7 +913,7 @@ async def oai_chat_completions(user_prompt, enable_streaming=None, response_format=None, user_api_key=None, - expect_api_error=None): + expect_api_error=None) -> int | dict[str, Any]: if debug: print(f"Sending OAI Chat completions request: {user_prompt}") # openai client always expects an api key @@ -989,32 +993,35 @@ async def oai_chat_completions(user_prompt, else: try: openai.api_key = user_api_key - openai.api_base = f'{base_url}{base_path}' - chat_completion = openai.Completion.create( + openai.base_url = f'{base_url}{base_path.removesuffix("chat")}' + assert model is not None + chat_completion = openai.chat.completions.create( messages=payload['messages'], model=model, max_tokens=n_predict, stream=enable_streaming, - response_format=payload.get('response_format'), + response_format=payload.get('response_format') or openai.NOT_GIVEN, seed=seed, temperature=payload['temperature'] ) - except openai.error.AuthenticationError as e: + except openai.AuthenticationError as e: if expect_api_error is not None and expect_api_error: return 401 else: assert False, f'error raised: {e}' if enable_streaming: + chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion) for chunk in chat_completion: assert len(chunk.choices) == 1 delta = chunk.choices[0].delta - if 'content' in delta: - completion_response['content'] += delta['content'] + if delta.content is not None: + completion_response['content'] += delta.content completion_response['timings']['predicted_n'] += 1 completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop' else: assert len(chat_completion.choices) == 1 + assert chat_completion.usage is not None completion_response = { 'content': chat_completion.choices[0].message.content, 'timings': { @@ -1028,7 +1035,7 @@ async def oai_chat_completions(user_prompt, return completion_response -async def request_embedding(content, seed, base_url=None): +async def request_embedding(content, seed, base_url=None) -> list[list[float]]: async with aiohttp.ClientSession() as session: async with session.post(f'{base_url}/embedding', json={ @@ -1041,7 +1048,7 @@ async def request_embedding(content, seed, base_url=None): async def request_oai_embeddings(input, seed, base_url=None, user_api_key=None, - model=None, async_client=False): + model=None, async_client=False) -> list[list[float]]: # openai client always expects an api_key user_api_key = user_api_key if user_api_key is not None else 'nope' if async_client: @@ -1063,7 +1070,7 @@ async def request_oai_embeddings(input, seed, response_json = await response.json() assert response_json['model'] == model, f"invalid model received: {response_json['model']}" assert response_json['object'] == 'list' - if isinstance(input, collections.abc.Sequence): + if isinstance(input, Sequence): embeddings = [] for an_oai_embeddings in response_json['data']: embeddings.append(an_oai_embeddings['embedding']) @@ -1072,19 +1079,14 @@ async def request_oai_embeddings(input, seed, return embeddings else: openai.api_key = user_api_key - openai.api_base = f'{base_url}/v1' - oai_embeddings = openai.Embedding.create( + openai.base_url = f'{base_url}/v1/' + assert model is not None + oai_embeddings = openai.embeddings.create( model=model, input=input, ) - if isinstance(input, collections.abc.Sequence): - embeddings = [] - for an_oai_embeddings in oai_embeddings.data: - embeddings.append(an_oai_embeddings.embedding) - else: - embeddings = [oai_embeddings.data.embedding] - return embeddings + return [e.embedding for e in oai_embeddings.data] def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None): @@ -1122,7 +1124,7 @@ def assert_all_predictions_equal(completion_responses): if i == j: continue content_j = response_j['content'] - assert content_i == content_j, "contents not equal" + assert content_i == content_j, "contents not equal" def assert_all_predictions_different(completion_responses): @@ -1136,7 +1138,7 @@ def assert_all_predictions_different(completion_responses): if i == j: continue content_j = response_j['content'] - assert content_i != content_j, "contents not different" + assert content_i != content_j, "contents not different" def assert_all_token_probabilities_equal(completion_responses): @@ -1153,7 +1155,7 @@ def assert_all_token_probabilities_equal(completion_responses): if i == j: continue probs_j = response_j['completion_probabilities'][pos]['probs'] - assert probs_i == probs_j, "contents not equal" + assert probs_i == probs_j, "contents not equal" async def gather_tasks_results(context): @@ -1343,7 +1345,7 @@ def start_server_background(context): } context.server_process = subprocess.Popen( [str(arg) for arg in [context.server_path, *server_args]], - **pkwargs) + **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] def server_log(in_stream, out_stream): for line in iter(in_stream.readline, b''): diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index 2e4f42ad2..2c741ea10 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -1,6 +1,6 @@ aiohttp~=3.9.3 behave~=1.2.6 huggingface_hub~=0.20.3 -numpy~=1.24.4 -openai~=0.25.0 +numpy~=1.26.4 +openai~=1.30.3 prometheus-client~=0.20.0 diff --git a/examples/server_embd.py b/examples/server_embd.py index a9a36a44c..0e34c6cea 100644 --- a/examples/server_embd.py +++ b/examples/server_embd.py @@ -1,13 +1,15 @@ import asyncio +import asyncio.threads import requests import numpy as np + n = 8 result = [] async def requests_post_async(*args, **kwargs): - return await asyncio.to_thread(requests.post, *args, **kwargs) + return await asyncio.threads.to_thread(requests.post, *args, **kwargs) async def main(): model_url = "http://127.0.0.1:6900" diff --git a/examples/train-text-from-scratch/convert_train_checkpoint_to_gguf.py b/examples/train-text-from-scratch/convert_train_checkpoint_to_gguf.py index ed93673bc..e045beb72 100644 --- a/examples/train-text-from-scratch/convert_train_checkpoint_to_gguf.py +++ b/examples/train-text-from-scratch/convert_train_checkpoint_to_gguf.py @@ -66,7 +66,7 @@ class Tensor: if len(self.ne) == 0: self.nbytes = 0 else: - self.nbytes = int(np.product(self.ne)) * 4 + self.nbytes = int(np.prod(self.ne)) * 4 else: raise ValueError(f"Unhandled data type '{self.dtype}'") diff --git a/ggml/ggml_vk_generate_shaders.py b/ggml/ggml_vk_generate_shaders.py index 38914eedb..41d5d9b8c 100644 --- a/ggml/ggml_vk_generate_shaders.py +++ b/ggml/ggml_vk_generate_shaders.py @@ -99,6 +99,8 @@ async def main(): tasks = [] + base_dict = {"FLOAT_TYPE": "float"} + for fp16 in (False, True): # MUL_MAT matmul_shaders(tasks, fp16, False) @@ -106,8 +108,6 @@ async def main(): matmul_shaders(tasks, fp16, True) for tname in type_names: - base_dict = {"FLOAT_TYPE": "float"} - # mul mat vec data_a_key = f"DATA_A_{tname.upper()}" shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp" diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index 20432bd25..e8e61abf8 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -67,7 +67,7 @@ class ReaderTensor(NamedTuple): class GGUFReader: # I - same as host, S - swapped - byte_order: Literal['I'] | Literal['S'] = 'I' + byte_order: Literal['I', 'S'] = 'I' alignment: int = GGUF_DEFAULT_ALIGNMENT data_offset: int @@ -86,7 +86,7 @@ class GGUFReader: GGUFValueType.BOOL: np.bool_, } - def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'): + def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'): self.data = np.memmap(path, mode = mode) offs = 0 @@ -140,7 +140,7 @@ class GGUFReader: return self.tensors[idx] def _get( - self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None, + self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, ) -> npt.NDArray[Any]: count = int(count) itemsize = int(np.empty([], dtype = dtype).itemsize) diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py index 1167335b8..c50124cd9 100644 --- a/gguf-py/gguf/lazy.py +++ b/gguf-py/gguf/lazy.py @@ -16,16 +16,16 @@ logger = logging.getLogger(__name__) class LazyMeta(ABCMeta): def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs): - def __getattr__(self, __name: str) -> Any: - meta_attr = getattr(self._meta, __name) + def __getattr__(self, name: str) -> Any: + meta_attr = getattr(self._meta, name) if callable(meta_attr): return type(self)._wrap_fn( - (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)), + (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)), use_self=self, ) elif isinstance(meta_attr, self._tensor_type): # e.g. self.T with torch.Tensor should still be wrapped - return type(self)._wrap_fn(lambda s: getattr(s, __name))(self) + return type(self)._wrap_fn(lambda s: getattr(s, name))(self) else: # no need to wrap non-tensor properties, # and they likely don't depend on the actual contents of the tensor @@ -141,19 +141,21 @@ class LazyBase(ABC, metaclass=LazyMeta): res = cls.meta_with_dtype_and_shape(meta_noop, res.shape) if isinstance(res, cls._tensor_type): - def collect_replace(t: LazyBase): - if collect_replace.shared_lazy is None: - collect_replace.shared_lazy = t._lazy - else: - collect_replace.shared_lazy.extend(t._lazy) - t._lazy = collect_replace.shared_lazy + class CollectSharedLazy: + # emulating a static variable + shared_lazy: None | deque[LazyBase] = None - # emulating a static variable - collect_replace.shared_lazy = None + @staticmethod + def collect_replace(t: LazyBase): + if CollectSharedLazy.shared_lazy is None: + CollectSharedLazy.shared_lazy = t._lazy + else: + CollectSharedLazy.shared_lazy.extend(t._lazy) + t._lazy = CollectSharedLazy.shared_lazy - LazyBase._recurse_apply(args, collect_replace) + LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace) - shared_lazy = collect_replace.shared_lazy + shared_lazy = CollectSharedLazy.shared_lazy return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs)) else: @@ -184,6 +186,7 @@ class LazyBase(ABC, metaclass=LazyMeta): lt._args = cls._recurse_apply(lt._args, already_eager_to_eager) lt._data = lt._func(lt._args) # sanity check + assert lt._data is not None assert lt._data.dtype == lt._meta.dtype assert lt._data.shape == lt._meta.shape diff --git a/gguf-py/scripts/__init__.py b/gguf-py/scripts/__init__.py index f9d29cb69..e77f2e9c9 100644 --- a/gguf-py/scripts/__init__.py +++ b/gguf-py/scripts/__init__.py @@ -1,3 +1,5 @@ +# pyright: reportUnusedImport=false + from .gguf_convert_endian import main as gguf_convert_endian_entrypoint from .gguf_dump import main as gguf_dump_entrypoint from .gguf_set_metadata import main as gguf_set_metadata_entrypoint diff --git a/gguf-py/scripts/gguf_hash.py b/gguf-py/scripts/gguf_hash.py index 956775182..770b79a93 100755 --- a/gguf-py/scripts/gguf_hash.py +++ b/gguf-py/scripts/gguf_hash.py @@ -63,9 +63,9 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None: bar.update(sum_weights_in_tensor) sha1_layer = hashlib.sha1() - sha1_layer.update(tensor.data) - sha1.update(tensor.data) - uuidv5_sha1.update(tensor.data) + sha1_layer.update(tensor.data.data) + sha1.update(tensor.data.data) + uuidv5_sha1.update(tensor.data.data) print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100 # Flush Hash Progress Bar diff --git a/gguf-py/scripts/gguf_new_metadata.py b/gguf-py/scripts/gguf_new_metadata.py index c4b90d581..fce52a8c1 100755 --- a/gguf-py/scripts/gguf_new_metadata.py +++ b/gguf-py/scripts/gguf_new_metadata.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import logging import argparse import os diff --git a/gguf-py/tests/test_gguf.py b/gguf-py/tests/test_gguf.py index 0adeb7d55..76b52181e 100644 --- a/gguf-py/tests/test_gguf.py +++ b/gguf-py/tests/test_gguf.py @@ -1,4 +1,4 @@ -import gguf # noqa: F401 +import gguf # noqa: F401 # pyright: ignore[reportUnusedImport] # TODO: add tests diff --git a/pyrightconfig.json b/pyrightconfig.json index 020a71a4e..6016f4b6d 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,3 +1,21 @@ { "extraPaths": ["gguf-py"], -} + "pythonVersion": "3.9", + "pythonPlatform": "All", + "reportUnusedImport": "warning", + "reportDuplicateImport": "error", + "reportDeprecated": "warning", + "reportUnnecessaryTypeIgnoreComment": "warning", + "executionEnvironments": [ + { + // TODO: make this version override work correctly + "root": "gguf-py", + "pythonVersion": "3.8", + }, + { + // uses match expressions in steps.py + "root": "examples/server/tests", + "pythonVersion": "3.10", + }, + ], + } diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt new file mode 100644 index 000000000..94de59d7e --- /dev/null +++ b/requirements/requirements-all.txt @@ -0,0 +1,12 @@ +-r ../examples/llava/requirements.txt +-r ../examples/server/bench/requirements.txt +-r ../examples/server/tests/requirements.txt + +-r ./requirements-compare-llama-bench.txt +-r ./requirements-pydantic.txt +-r ./requirements-test-tokenizer-random.txt + +-r ./requirements-convert_hf_to_gguf.txt +-r ./requirements-convert_hf_to_gguf_update.txt +-r ./requirements-convert_legacy_llama.txt +-r ./requirements-convert_llama_ggml_to_gguf.txt diff --git a/requirements/requirements-compare-llama-bench.txt b/requirements/requirements-compare-llama-bench.txt new file mode 100644 index 000000000..e0aaa3204 --- /dev/null +++ b/requirements/requirements-compare-llama-bench.txt @@ -0,0 +1,2 @@ +tabulate~=0.9.0 +GitPython~=3.1.43 diff --git a/requirements/requirements-pydantic.txt b/requirements/requirements-pydantic.txt new file mode 100644 index 000000000..2f9455b14 --- /dev/null +++ b/requirements/requirements-pydantic.txt @@ -0,0 +1,2 @@ +docstring_parser~=0.15 +pydantic~=2.6.3 diff --git a/requirements/requirements-test-tokenizer-random.txt b/requirements/requirements-test-tokenizer-random.txt new file mode 100644 index 000000000..2785e71a2 --- /dev/null +++ b/requirements/requirements-test-tokenizer-random.txt @@ -0,0 +1 @@ +cffi~=1.16.0 diff --git a/scripts/check-requirements.sh b/scripts/check-requirements.sh index 48f924c02..d3bbded13 100755 --- a/scripts/check-requirements.sh +++ b/scripts/check-requirements.sh @@ -108,6 +108,11 @@ check_convert_script() { fatal "$py missing requirements. Expected: $reqs" fi + # Check that all sub-requirements are added to top-level requirements.txt + if ! grep -qF "$reqs" requirements.txt; then + fatal "$reqs needs to be added to requirements.txt" + fi + local venv="$workdir/$pyname-venv" python3 -m venv "$venv" @@ -134,12 +139,7 @@ EOF readonly ignore_eq_eq='check_requirements: ignore "=="' -for req in "$reqs_dir"/*; do - # Check that all sub-requirements are added to top-level requirements.txt - if ! grep -qF "$req" requirements.txt; then - fatal "$req needs to be added to requirements.txt" - fi - +for req in */**/requirements*.txt; do # Make sure exact release versions aren't being pinned in the requirements # Filters out the ignore string if grep -vF "$ignore_eq_eq" "$req" | grep -q '=='; then diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 513dde5e1..92b9e682a 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -123,13 +123,13 @@ builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() try: repo = git.Repo(".", search_parent_directories=True) -except git.exc.InvalidGitRepositoryError: +except git.InvalidGitRepositoryError: repo = None -def find_parent_in_data(commit): +def find_parent_in_data(commit: git.Commit): """Helper function to find the most recent parent measured in number of commits for which there is data.""" - heap = [(0, commit)] + heap: list[tuple[int, git.Commit]] = [(0, commit)] seen_hexsha8 = set() while heap: depth, current_commit = heapq.heappop(heap) @@ -144,7 +144,7 @@ def find_parent_in_data(commit): return None -def get_all_parent_hexsha8s(commit): +def get_all_parent_hexsha8s(commit: git.Commit): """Helper function to recursively get hexsha8 values for all parents of a commit.""" unvisited = [commit] visited = [] diff --git a/scripts/gen-unicode-data.py b/scripts/gen-unicode-data.py index 890e4d7c2..2d9bde01c 100644 --- a/scripts/gen-unicode-data.py +++ b/scripts/gen-unicode-data.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import array import unicodedata import requests @@ -133,7 +135,7 @@ table_nfd.sort() # group ranges with same flags -ranges_flags = [(0, codepoint_flags[0])] # start, flags +ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])] # start, flags for codepoint, flags in enumerate(codepoint_flags): if flags != ranges_flags[-1][1]: ranges_flags.append((codepoint, flags)) @@ -141,11 +143,11 @@ ranges_flags.append((MAX_CODEPOINTS, 0x0000)) # group ranges with same nfd -ranges_nfd = [(0, 0, 0)] # start, last, nfd +ranges_nfd: list[tuple[int, int, int]] = [(0, 0, 0)] # start, last, nfd for codepoint, norm in table_nfd: start = ranges_nfd[-1][0] if ranges_nfd[-1] != (start, codepoint - 1, norm): - ranges_nfd.append(None) + ranges_nfd.append(None) # type: ignore[arg-type] # dummy, will be replaced below start = codepoint ranges_nfd[-1] = (start, codepoint, norm) @@ -179,13 +181,13 @@ for codepoint in table_whitespace: out("};\n") out("const std::unordered_map unicode_map_lowercase = {") -for tuple in table_lowercase: - out("{0x%06X, 0x%06X}," % tuple) +for tuple_lw in table_lowercase: + out("{0x%06X, 0x%06X}," % tuple_lw) out("};\n") out("const std::unordered_map unicode_map_uppercase = {") -for tuple in table_uppercase: - out("{0x%06X, 0x%06X}," % tuple) +for tuple_up in table_uppercase: + out("{0x%06X, 0x%06X}," % tuple_up) out("};\n") out("const std::vector unicode_ranges_nfd = { // start, last, nfd") diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index 48cab8a1e..c50a8ca32 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -6,6 +6,8 @@ # python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe # +from __future__ import annotations + import time import logging import argparse @@ -13,7 +15,9 @@ import subprocess import random import unicodedata -from typing import Iterator +from pathlib import Path +from typing import Any, Iterator, cast +from typing_extensions import Buffer import cffi from transformers import AutoTokenizer @@ -28,15 +32,15 @@ class LibLlama: DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"] DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON - def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path_libllama: str = None): + def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None): path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H path_includes = path_includes or self.DEFAULT_PATH_INCLUDES path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama) self.lib.llama_backend_init() - def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str): - cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="] + def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]: + cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="] cmd += ["-I" + path for path in path_includes] + [path_llama_h] res = subprocess.run(cmd, stdout=subprocess.PIPE) assert (res.returncode == 0) @@ -68,7 +72,7 @@ class LibLlama: class LibLlamaModel: def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}): - self.lib = libllama.lib + self.lib: Any = libllama.lib self.ffi = libllama.ffi if isinstance(mparams, dict): mparams = libllama.model_default_params(**mparams) @@ -94,11 +98,11 @@ class LibLlamaModel: self.lib = None def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]: - text = text.encode("utf-8") - num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special) + encoded_text: bytes = text.encode("utf-8") + num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special) while num < 0 and len(self.token_ids) < (16 << 20): self.token_ids = self.ffi.new("llama_token[]", -2 * num) - num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special) + num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special) return list(self.token_ids[0:num]) def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str: @@ -110,7 +114,7 @@ class LibLlamaModel: while num < 0 and len(self.text_buff) < (16 << 20): self.text_buff = self.ffi.new("uint8_t[]", -2 * num) num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special) - return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD' + return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD' class Tokenizer: @@ -152,7 +156,7 @@ class TokenizerGroundtruth (Tokenizer): class TokenizerLlamaCpp (Tokenizer): - libllama: LibLlama = None + libllama: LibLlama | None = None def __init__(self, vocab_file: str): if not self.libllama: @@ -404,7 +408,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]): - def find_first_mismatch(ids1: list[int], ids2: list[int]): + def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str): for i, (a, b) in enumerate(zip(ids1, ids2)): if a != b: return i @@ -433,7 +437,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl decode_errors = 0 MAX_ERRORS = 10 - logger.info("%s: %s" % (generator.__name__, "ini")) + logger.info("%s: %s" % (generator.__qualname__, "ini")) for text in generator: # print(repr(text), text.encode()) # print(repr(text), hex(ord(text[0])), text.encode()) @@ -472,13 +476,13 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl break t_total = time.perf_counter() - t_start - logger.info(f"{generator.__name__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}") + logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}") -def main(argv: list[str] = None): +def main(argv: list[str] | None = None): parser = argparse.ArgumentParser() - parser.add_argument("vocab_file", help="path to vocab 'gguf' file") - parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file") + parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file") + parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file") parser.add_argument("--verbose", action="store_true", help="increase output verbosity") args = parser.parse_args(argv) @@ -520,7 +524,7 @@ if __name__ == "__main__": format = "%(levelname)s %(message)s", ) - path_tokenizers = "./models/tokenizers/" + path_tokenizers = Path("./models/tokenizers/") path_vocab_format = "./models/ggml-vocab-%s.gguf" tokenizers = [ @@ -556,6 +560,6 @@ if __name__ == "__main__": for tokenizer in tokenizers: logger.info("-" * 50) logger.info(f"TOKENIZER: '{tokenizer}'") - vocab_file = path_vocab_format % tokenizer - dir_tokenizer = path_tokenizers + "/" + tokenizer - main([vocab_file, dir_tokenizer, "--verbose"]) + vocab_file = Path(path_vocab_format % tokenizer) + dir_tokenizer = path_tokenizers / tokenizer + main([str(vocab_file), str(dir_tokenizer), "--verbose"])