mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
convert : fix python 3.8 support, modernize type annotations (#2916)
* convert : fix python 3.8 support * convert : sort imports * convert : fix required parameters in convert-llama-ggmlv3-to-gguf * convert : fix mypy errors in convert-llama-ggmlv3-to-gguf * convert : use PEP 585 generics and PEP 604 unions Now that we have `from __future__ import annotations`, we can use this modern syntax in Python 3.7 instead of restricting support to Python 3.9 or 3.10 respectively. * gguf.py : a tuple is already a tuple * add mypy.ini * convert : add necessary `type: ignore` comments * gguf-py: bump version
This commit is contained in:
parent
8afe228000
commit
92d0b751a7
@ -1,18 +1,21 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# HF falcon--> gguf conversion
|
# HF falcon--> gguf conversion
|
||||||
|
|
||||||
import gguf
|
from __future__ import annotations
|
||||||
import os
|
|
||||||
import sys
|
import argparse
|
||||||
import struct
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gguf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
from transformers import AutoTokenizer # type: ignore[import]
|
||||||
|
|
||||||
from typing import Any, List
|
|
||||||
from pathlib import Path
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
def bytes_to_unicode():
|
def bytes_to_unicode():
|
||||||
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||||
@ -114,9 +117,9 @@ gguf_writer.add_file_type(ftype)
|
|||||||
|
|
||||||
print("gguf: get tokenizer metadata")
|
print("gguf: get tokenizer metadata")
|
||||||
|
|
||||||
tokens: List[bytearray] = []
|
tokens: list[bytearray] = []
|
||||||
scores: List[float] = []
|
scores: list[float] = []
|
||||||
toktypes: List[int] = []
|
toktypes: list[int] = []
|
||||||
|
|
||||||
tokenizer_json_file = dir_model / 'tokenizer.json'
|
tokenizer_json_file = dir_model / 'tokenizer.json'
|
||||||
if not tokenizer_json_file.is_file():
|
if not tokenizer_json_file.is_file():
|
||||||
|
@ -1,18 +1,20 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# HF gptneox--> gguf conversion
|
# HF gptneox--> gguf conversion
|
||||||
|
|
||||||
import gguf
|
from __future__ import annotations
|
||||||
import os
|
|
||||||
import sys
|
import argparse
|
||||||
import struct
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gguf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
from transformers import AutoTokenizer # type: ignore[import]
|
||||||
|
|
||||||
from typing import Any, List
|
|
||||||
from pathlib import Path
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||||
|
|
||||||
@ -112,7 +114,7 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_eps"])
|
|||||||
|
|
||||||
print("gguf: get tokenizer metadata")
|
print("gguf: get tokenizer metadata")
|
||||||
|
|
||||||
tokens: List[bytearray] = []
|
tokens: list[bytearray] = []
|
||||||
|
|
||||||
tokenizer_json_file = dir_model / 'tokenizer.json'
|
tokenizer_json_file = dir_model / 'tokenizer.json'
|
||||||
if not tokenizer_json_file.is_file():
|
if not tokenizer_json_file.is_file():
|
||||||
|
@ -3,22 +3,25 @@
|
|||||||
# Only models with a single datafile are supported, like 7B
|
# Only models with a single datafile are supported, like 7B
|
||||||
# HF files required in the model dir: config.json tokenizer_config.json tokenizer.json tokenizer.model
|
# HF files required in the model dir: config.json tokenizer_config.json tokenizer.json tokenizer.model
|
||||||
|
|
||||||
import gguf
|
from __future__ import annotations
|
||||||
import os
|
|
||||||
import sys
|
import argparse
|
||||||
import struct
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import gguf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
from sentencepiece import SentencePieceProcessor # type: ignore[import]
|
||||||
|
|
||||||
from typing import Any, List, TypeAlias
|
if TYPE_CHECKING:
|
||||||
from pathlib import Path
|
from typing import TypeAlias
|
||||||
from sentencepiece import SentencePieceProcessor
|
|
||||||
|
|
||||||
#NDArray = np.ndarray[Any, Any]
|
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
|
||||||
# compatible with python < 3.9
|
|
||||||
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
|
||||||
|
|
||||||
|
|
||||||
def count_model_parts(dir_model: Path) -> int:
|
def count_model_parts(dir_model: Path) -> int:
|
||||||
@ -129,9 +132,9 @@ if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in
|
|||||||
|
|
||||||
print("gguf: get tokenizer metadata")
|
print("gguf: get tokenizer metadata")
|
||||||
|
|
||||||
tokens: List[bytes] = []
|
tokens: list[bytes] = []
|
||||||
scores: List[float] = []
|
scores: list[float] = []
|
||||||
toktypes: List[int] = []
|
toktypes: list[int] = []
|
||||||
|
|
||||||
tokenizer_model_file = dir_model / 'tokenizer.model'
|
tokenizer_model_file = dir_model / 'tokenizer.model'
|
||||||
if not tokenizer_model_file.is_file():
|
if not tokenizer_model_file.is_file():
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import sys, struct, math, argparse
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import gguf
|
import gguf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Note: Does not support GGML_QKK_64
|
# Note: Does not support GGML_QKK_64
|
||||||
QK_K = 256
|
QK_K = 256
|
||||||
@ -72,7 +76,7 @@ class Vocab:
|
|||||||
class Tensor:
|
class Tensor:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.name = None
|
self.name = None
|
||||||
self.dims = ()
|
self.dims: tuple[int, ...] = ()
|
||||||
self.dtype = None
|
self.dtype = None
|
||||||
self.start_offset = 0
|
self.start_offset = 0
|
||||||
self.len_bytes = np.int64(0)
|
self.len_bytes = np.int64(0)
|
||||||
@ -119,7 +123,7 @@ class GGMLV3Model:
|
|||||||
offset += hp.load(data, offset)
|
offset += hp.load(data, offset)
|
||||||
vocab = Vocab()
|
vocab = Vocab()
|
||||||
offset += vocab.load(data, offset, hp.n_vocab)
|
offset += vocab.load(data, offset, hp.n_vocab)
|
||||||
tensors = []
|
tensors: list[Tensor] = []
|
||||||
tensor_map = {}
|
tensor_map = {}
|
||||||
while offset < len(data):
|
while offset < len(data):
|
||||||
tensor = Tensor()
|
tensor = Tensor()
|
||||||
@ -305,8 +309,8 @@ def handle_metadata(cfg, hp):
|
|||||||
|
|
||||||
def handle_args():
|
def handle_args():
|
||||||
parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF')
|
parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF')
|
||||||
parser.add_argument('--input', '-i', type = Path, help = 'Input GGMLv3 filename')
|
parser.add_argument('--input', '-i', type = Path, required = True, help = 'Input GGMLv3 filename')
|
||||||
parser.add_argument('--output', '-o', type = Path, help ='Output GGUF filename')
|
parser.add_argument('--output', '-o', type = Path, required = True, help ='Output GGUF filename')
|
||||||
parser.add_argument('--name', help = 'Set model name')
|
parser.add_argument('--name', help = 'Set model name')
|
||||||
parser.add_argument('--desc', help = 'Set model description')
|
parser.add_argument('--desc', help = 'Set model description')
|
||||||
parser.add_argument('--gqa', type = int, default = 1, help = 'grouped-query attention factor (use 8 for LLaMA2 70B)')
|
parser.add_argument('--gqa', type = int, default = 1, help = 'grouped-query attention factor (use 8 for LLaMA2 70B)')
|
||||||
|
@ -1,28 +1,31 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# HF llama --> gguf conversion
|
# HF llama --> gguf conversion
|
||||||
|
|
||||||
import gguf
|
from __future__ import annotations
|
||||||
import os
|
|
||||||
import sys
|
import argparse
|
||||||
import struct
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import gguf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
from sentencepiece import SentencePieceProcessor # type: ignore[import]
|
||||||
|
|
||||||
from typing import Any, List, Optional, TypeAlias
|
if TYPE_CHECKING:
|
||||||
from pathlib import Path
|
from typing import TypeAlias
|
||||||
from sentencepiece import SentencePieceProcessor
|
|
||||||
|
|
||||||
#NDArray = np.ndarray[Any, Any]
|
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
|
||||||
# compatible with python < 3.9
|
|
||||||
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
|
||||||
|
|
||||||
# reverse HF permute back to original pth layout
|
# reverse HF permute back to original pth layout
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
|
||||||
|
|
||||||
|
|
||||||
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
|
||||||
if n_kv_head is not None and n_head != n_kv_head:
|
if n_kv_head is not None and n_head != n_kv_head:
|
||||||
n_head //= n_kv_head
|
n_head //= n_kv_head
|
||||||
|
|
||||||
@ -136,9 +139,9 @@ if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in
|
|||||||
|
|
||||||
print("gguf: get tokenizer metadata")
|
print("gguf: get tokenizer metadata")
|
||||||
|
|
||||||
tokens: List[bytes] = []
|
tokens: list[bytes] = []
|
||||||
scores: List[float] = []
|
scores: list[float] = []
|
||||||
toktypes: List[int] = []
|
toktypes: list[int] = []
|
||||||
|
|
||||||
tokenizer_model_file = dir_model / 'tokenizer.model'
|
tokenizer_model_file = dir_model / 'tokenizer.model'
|
||||||
if not tokenizer_model_file.is_file():
|
if not tokenizer_model_file.is_file():
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Dict, Sequence, BinaryIO
|
from typing import Any, BinaryIO, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
NUMPY_TYPE_TO_FTYPE: Dict[str, int] = {"float32": 0, "float16": 1}
|
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
|
||||||
|
|
||||||
|
|
||||||
HF_SUBLAYER_TO_GGML = {
|
HF_SUBLAYER_TO_GGML = {
|
||||||
@ -46,7 +48,7 @@ def translate_tensor_name(t: str) -> str:
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def write_file_header(fout: BinaryIO, params: Dict[str, Any]) -> None:
|
def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
||||||
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
||||||
fout.write(struct.pack("i", 1)) # file version
|
fout.write(struct.pack("i", 1)) # file version
|
||||||
fout.write(struct.pack("i", params["r"]))
|
fout.write(struct.pack("i", params["r"]))
|
||||||
|
149
convert.py
149
convert.py
@ -1,9 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import gguf
|
|
||||||
import argparse
|
import argparse
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
|
||||||
import copy
|
import copy
|
||||||
import enum
|
import enum
|
||||||
import faulthandler
|
import faulthandler
|
||||||
@ -20,21 +19,23 @@ import struct
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import zipfile
|
import zipfile
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
|
from typing import IO, TYPE_CHECKING, Any, Callable, Generator, Iterable, Literal, Sequence, TypeVar
|
||||||
from sentencepiece import SentencePieceProcessor # type: ignore
|
|
||||||
|
import gguf
|
||||||
|
import numpy as np
|
||||||
|
from sentencepiece import SentencePieceProcessor # type: ignore[import]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing_extensions import TypeAlias
|
from typing import TypeAlias
|
||||||
|
|
||||||
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
|
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
|
||||||
faulthandler.register(signal.SIGUSR1)
|
faulthandler.register(signal.SIGUSR1)
|
||||||
|
|
||||||
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
|
||||||
|
|
||||||
ARCH=gguf.MODEL_ARCH.LLAMA
|
ARCH=gguf.MODEL_ARCH.LLAMA
|
||||||
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
|
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
|
||||||
@ -47,8 +48,8 @@ DEFAULT_CONCURRENCY = 8
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DataType:
|
class DataType:
|
||||||
name: str
|
name: str
|
||||||
dtype: 'np.dtype[Any]'
|
dtype: np.dtype[Any]
|
||||||
valid_conversions: List[str]
|
valid_conversions: list[str]
|
||||||
|
|
||||||
def elements_to_bytes(self, n_elements: int) -> int:
|
def elements_to_bytes(self, n_elements: int) -> int:
|
||||||
return n_elements * self.dtype.itemsize
|
return n_elements * self.dtype.itemsize
|
||||||
@ -65,7 +66,7 @@ DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_convers
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class QuantizedDataType(DataType):
|
class QuantizedDataType(DataType):
|
||||||
block_size: int
|
block_size: int
|
||||||
quantized_dtype: 'np.dtype[Any]'
|
quantized_dtype: np.dtype[Any]
|
||||||
ggml_type: gguf.GGMLQuantizationType
|
ggml_type: gguf.GGMLQuantizationType
|
||||||
|
|
||||||
def quantize(self, arr: NDArray) -> NDArray:
|
def quantize(self, arr: NDArray) -> NDArray:
|
||||||
@ -84,7 +85,7 @@ class Q8_0QuantizedDataType(QuantizedDataType):
|
|||||||
n_blocks = arr.size // self.block_size
|
n_blocks = arr.size // self.block_size
|
||||||
blocks = arr.reshape((n_blocks, self.block_size))
|
blocks = arr.reshape((n_blocks, self.block_size))
|
||||||
# Much faster implementation of block quantization contributed by @Cebtenzzre
|
# Much faster implementation of block quantization contributed by @Cebtenzzre
|
||||||
def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
|
def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]:
|
||||||
d = abs(blocks).max(axis = 1) / np.float32(127)
|
d = abs(blocks).max(axis = 1) / np.float32(127)
|
||||||
with np.errstate(divide = 'ignore'):
|
with np.errstate(divide = 'ignore'):
|
||||||
qs = (blocks / d[:, None]).round()
|
qs = (blocks / d[:, None]).round()
|
||||||
@ -98,13 +99,13 @@ DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
|
|||||||
quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
|
quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
|
||||||
|
|
||||||
# Quantized types skipped here because they may also map to np.float32
|
# Quantized types skipped here because they may also map to np.float32
|
||||||
NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
|
NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {}
|
||||||
for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
|
for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
|
||||||
if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
|
if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
|
||||||
raise ValueError(f'Invalid duplicate data type {dt}')
|
raise ValueError(f'Invalid duplicate data type {dt}')
|
||||||
NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
|
NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
|
||||||
|
|
||||||
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
|
SAFETENSORS_DATA_TYPES: dict[str, DataType] = {
|
||||||
'BF16': DT_BF16,
|
'BF16': DT_BF16,
|
||||||
'F16': DT_F16,
|
'F16': DT_F16,
|
||||||
'F32': DT_F32,
|
'F32': DT_F32,
|
||||||
@ -119,14 +120,14 @@ class GGMLFileType(enum.IntEnum):
|
|||||||
MostlyF16 = 1 # except 1d tensors
|
MostlyF16 = 1 # except 1d tensors
|
||||||
MostlyQ8_0 = 7 # except 1d tensors
|
MostlyQ8_0 = 7 # except 1d tensors
|
||||||
|
|
||||||
def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
|
def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
|
||||||
dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
|
dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
|
||||||
if dt is None:
|
if dt is None:
|
||||||
raise ValueError(self)
|
raise ValueError(self)
|
||||||
# 1D tensors are always F32.
|
# 1D tensors are always F32.
|
||||||
return dt if len(tensor.shape) > 1 else DT_F32
|
return dt if len(tensor.shape) > 1 else DT_F32
|
||||||
|
|
||||||
GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
|
GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
|
||||||
GGMLFileType.AllF32 : DT_F32,
|
GGMLFileType.AllF32 : DT_F32,
|
||||||
GGMLFileType.MostlyF16 : DT_F16,
|
GGMLFileType.MostlyF16 : DT_F16,
|
||||||
GGMLFileType.MostlyQ8_0: DT_Q8_0,
|
GGMLFileType.MostlyQ8_0: DT_Q8_0,
|
||||||
@ -148,13 +149,13 @@ class Params:
|
|||||||
n_head_kv: int
|
n_head_kv: int
|
||||||
f_norm_eps: float
|
f_norm_eps: float
|
||||||
|
|
||||||
f_rope_freq_base: Optional[float] = None
|
f_rope_freq_base: float | None = None
|
||||||
f_rope_scale: Optional[float] = None
|
f_rope_scale: float | None = None
|
||||||
|
|
||||||
ftype: Optional[GGMLFileType] = None
|
ftype: GGMLFileType | None = None
|
||||||
|
|
||||||
# path to the directory containing the model files
|
# path to the directory containing the model files
|
||||||
path_model: Optional['Path'] = None
|
path_model: Path | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_n_mult(n_ff: int, n_embd: int) -> int:
|
def find_n_mult(n_ff: int, n_embd: int) -> int:
|
||||||
@ -166,7 +167,7 @@ class Params:
|
|||||||
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
|
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def guessed(model: 'LazyModel') -> 'Params':
|
def guessed(model: LazyModel) -> Params:
|
||||||
# try transformer naming first
|
# try transformer naming first
|
||||||
n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
|
n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
|
||||||
|
|
||||||
@ -202,7 +203,7 @@ class Params:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
|
||||||
config = json.load(open(config_path))
|
config = json.load(open(config_path))
|
||||||
|
|
||||||
n_vocab = config["vocab_size"]
|
n_vocab = config["vocab_size"]
|
||||||
@ -247,7 +248,7 @@ class Params:
|
|||||||
# LLaMA v2 70B params.json
|
# LLaMA v2 70B params.json
|
||||||
# {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
|
# {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
|
||||||
config = json.load(open(config_path))
|
config = json.load(open(config_path))
|
||||||
|
|
||||||
n_vocab = config["vocab_size"] if "vocab_size" in config else -1
|
n_vocab = config["vocab_size"] if "vocab_size" in config else -1
|
||||||
@ -291,7 +292,7 @@ class Params:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(model_plus: 'ModelPlus') -> 'Params':
|
def load(model_plus: ModelPlus) -> Params:
|
||||||
hf_config_path = model_plus.paths[0].parent / "config.json"
|
hf_config_path = model_plus.paths[0].parent / "config.json"
|
||||||
orig_config_path = model_plus.paths[0].parent / "params.json"
|
orig_config_path = model_plus.paths[0].parent / "params.json"
|
||||||
|
|
||||||
@ -314,9 +315,9 @@ class Params:
|
|||||||
#
|
#
|
||||||
|
|
||||||
class BpeVocab:
|
class BpeVocab:
|
||||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
|
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
|
||||||
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
|
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
|
||||||
added_tokens: Dict[str, int]
|
added_tokens: dict[str, int]
|
||||||
if fname_added_tokens is not None:
|
if fname_added_tokens is not None:
|
||||||
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
|
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
|
||||||
else:
|
else:
|
||||||
@ -335,9 +336,9 @@ class BpeVocab:
|
|||||||
self.fname_tokenizer = fname_tokenizer
|
self.fname_tokenizer = fname_tokenizer
|
||||||
self.fname_added_tokens = fname_added_tokens
|
self.fname_added_tokens = fname_added_tokens
|
||||||
|
|
||||||
def bpe_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
|
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
tokenizer = self.bpe_tokenizer
|
tokenizer = self.bpe_tokenizer
|
||||||
from transformers.models.gpt2 import tokenization_gpt2
|
from transformers.models.gpt2 import tokenization_gpt2 # type: ignore[import]
|
||||||
byte_encoder = tokenization_gpt2.bytes_to_unicode()
|
byte_encoder = tokenization_gpt2.bytes_to_unicode()
|
||||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||||
for i, item in enumerate(tokenizer):
|
for i, item in enumerate(tokenizer):
|
||||||
@ -345,12 +346,12 @@ class BpeVocab:
|
|||||||
score: float = -i
|
score: float = -i
|
||||||
yield text, score, gguf.TokenType.USER_DEFINED
|
yield text, score, gguf.TokenType.USER_DEFINED
|
||||||
|
|
||||||
def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
|
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
for text in self.added_tokens_list:
|
for text in self.added_tokens_list:
|
||||||
score = -1000.0
|
score = -1000.0
|
||||||
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
|
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
|
||||||
|
|
||||||
def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
|
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
yield from self.bpe_tokens()
|
yield from self.bpe_tokens()
|
||||||
yield from self.added_tokens()
|
yield from self.added_tokens()
|
||||||
|
|
||||||
@ -359,9 +360,9 @@ class BpeVocab:
|
|||||||
|
|
||||||
|
|
||||||
class SentencePieceVocab:
|
class SentencePieceVocab:
|
||||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
|
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
|
||||||
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
|
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
|
||||||
added_tokens: Dict[str, int]
|
added_tokens: dict[str, int]
|
||||||
if fname_added_tokens is not None:
|
if fname_added_tokens is not None:
|
||||||
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
|
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
|
||||||
else:
|
else:
|
||||||
@ -380,7 +381,7 @@ class SentencePieceVocab:
|
|||||||
self.fname_tokenizer = fname_tokenizer
|
self.fname_tokenizer = fname_tokenizer
|
||||||
self.fname_added_tokens = fname_added_tokens
|
self.fname_added_tokens = fname_added_tokens
|
||||||
|
|
||||||
def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
|
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
tokenizer = self.sentencepiece_tokenizer
|
tokenizer = self.sentencepiece_tokenizer
|
||||||
for i in range(tokenizer.vocab_size()):
|
for i in range(tokenizer.vocab_size()):
|
||||||
piece = tokenizer.id_to_piece(i)
|
piece = tokenizer.id_to_piece(i)
|
||||||
@ -404,19 +405,19 @@ class SentencePieceVocab:
|
|||||||
|
|
||||||
yield text, score, toktype
|
yield text, score, toktype
|
||||||
|
|
||||||
def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
|
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
for text in self.added_tokens_list:
|
for text in self.added_tokens_list:
|
||||||
score = -1000.0
|
score = -1000.0
|
||||||
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
|
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
|
||||||
|
|
||||||
def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
|
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
yield from self.sentencepiece_tokens()
|
yield from self.sentencepiece_tokens()
|
||||||
yield from self.added_tokens()
|
yield from self.added_tokens()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||||
|
|
||||||
Vocab = Union[BpeVocab, SentencePieceVocab]
|
Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
|
||||||
|
|
||||||
#
|
#
|
||||||
# data loading
|
# data loading
|
||||||
@ -436,15 +437,15 @@ class Tensor(metaclass=ABCMeta):
|
|||||||
data_type: DataType
|
data_type: DataType
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def astype(self, data_type: DataType) -> 'Tensor': ...
|
def astype(self, data_type: DataType) -> Tensor: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
|
def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': ...
|
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def part(self, n_part: int) -> 'UnquantizedTensor': ...
|
def part(self, n_part: int) -> UnquantizedTensor: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
|
def to_ggml(self) -> GGMLCompatibleTensor: ...
|
||||||
|
|
||||||
|
|
||||||
def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
|
def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
|
||||||
@ -465,22 +466,22 @@ class UnquantizedTensor(Tensor):
|
|||||||
self.ndarray = bf16_to_fp32(self.ndarray)
|
self.ndarray = bf16_to_fp32(self.ndarray)
|
||||||
return UnquantizedTensor(self.ndarray.astype(dtype))
|
return UnquantizedTensor(self.ndarray.astype(dtype))
|
||||||
|
|
||||||
def to_ggml(self) -> 'UnquantizedTensor':
|
def to_ggml(self) -> UnquantizedTensor:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
|
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
|
||||||
r = self.ndarray.shape[0] // 3
|
r = self.ndarray.shape[0] // 3
|
||||||
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
|
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
|
||||||
|
|
||||||
def part(self, n_part: int) -> 'UnquantizedTensor':
|
def part(self, n_part: int) -> UnquantizedTensor:
|
||||||
r = self.ndarray.shape[0] // 3
|
r = self.ndarray.shape[0] // 3
|
||||||
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
|
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
|
||||||
|
|
||||||
def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
|
def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor:
|
||||||
return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
|
return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
|
||||||
|
|
||||||
|
|
||||||
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
|
def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray:
|
||||||
tensor = lazy_tensor.load()
|
tensor = lazy_tensor.load()
|
||||||
assert isinstance(tensor, UnquantizedTensor)
|
assert isinstance(tensor, UnquantizedTensor)
|
||||||
|
|
||||||
@ -496,13 +497,13 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
|
|||||||
return tensor.ndarray
|
return tensor.ndarray
|
||||||
|
|
||||||
|
|
||||||
GGMLCompatibleTensor = Union[UnquantizedTensor]
|
GGMLCompatibleTensor = UnquantizedTensor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LazyTensor:
|
class LazyTensor:
|
||||||
_load: Callable[[], Tensor]
|
_load: Callable[[], Tensor]
|
||||||
shape: List[int]
|
shape: list[int]
|
||||||
data_type: DataType
|
data_type: DataType
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
@ -513,7 +514,7 @@ class LazyTensor:
|
|||||||
(self.data_type, ret.data_type, self.description)
|
(self.data_type, ret.data_type, self.description)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def astype(self, data_type: DataType) -> 'LazyTensor':
|
def astype(self, data_type: DataType) -> LazyTensor:
|
||||||
self.validate_conversion_to(data_type)
|
self.validate_conversion_to(data_type)
|
||||||
|
|
||||||
def load() -> Tensor:
|
def load() -> Tensor:
|
||||||
@ -525,24 +526,24 @@ class LazyTensor:
|
|||||||
raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
|
raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
|
||||||
|
|
||||||
|
|
||||||
LazyModel = Dict[str, LazyTensor]
|
LazyModel = dict[str, LazyTensor]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelPlus:
|
class ModelPlus:
|
||||||
model: LazyModel
|
model: LazyModel
|
||||||
paths: List[Path] # Where this was read from.
|
paths: list[Path] # Where this was read from.
|
||||||
format: Literal['ggml', 'torch', 'safetensors', 'none']
|
format: Literal['ggml', 'torch', 'safetensors', 'none']
|
||||||
vocab: Optional[Vocab] # For GGML models (which have vocab built in), the vocab.
|
vocab: Vocab | None # For GGML models (which have vocab built in), the vocab.
|
||||||
|
|
||||||
|
|
||||||
def merge_sharded(models: List[LazyModel]) -> LazyModel:
|
def merge_sharded(models: list[LazyModel]) -> LazyModel:
|
||||||
# Original LLaMA models have each file contain one part of each tensor.
|
# Original LLaMA models have each file contain one part of each tensor.
|
||||||
# Use a dict instead of a set to preserve order.
|
# Use a dict instead of a set to preserve order.
|
||||||
names = {name: None for model in models for name in model}
|
names = {name: None for model in models for name in model}
|
||||||
|
|
||||||
def convert(name: str) -> LazyTensor:
|
def convert(name: str) -> LazyTensor:
|
||||||
lazy_tensors: List[LazyTensor] = [model[name] for model in models]
|
lazy_tensors: list[LazyTensor] = [model[name] for model in models]
|
||||||
if len(lazy_tensors) == 1:
|
if len(lazy_tensors) == 1:
|
||||||
# only one file; don't go through this procedure since there might
|
# only one file; don't go through this procedure since there might
|
||||||
# be quantized tensors
|
# be quantized tensors
|
||||||
@ -570,7 +571,7 @@ def merge_sharded(models: List[LazyModel]) -> LazyModel:
|
|||||||
return {name: convert(name) for name in names}
|
return {name: convert(name) for name in names}
|
||||||
|
|
||||||
|
|
||||||
def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
|
def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
|
||||||
formats = set(mp.format for mp in models_plus)
|
formats = set(mp.format for mp in models_plus)
|
||||||
assert len(formats) == 1, "different formats?"
|
assert len(formats) == 1, "different formats?"
|
||||||
format = formats.pop()
|
format = formats.pop()
|
||||||
@ -674,7 +675,7 @@ class LazyUnpickler(pickle.Unpickler):
|
|||||||
def rebuild_from_type_v2(func, new_type, args, state):
|
def rebuild_from_type_v2(func, new_type, args, state):
|
||||||
return func(*args)
|
return func(*args)
|
||||||
|
|
||||||
CLASSES: Dict[Tuple[str, str], Any] = {
|
CLASSES: dict[tuple[str, str], Any] = {
|
||||||
# getattr used here as a workaround for mypy not being smart enough to detrmine
|
# getattr used here as a workaround for mypy not being smart enough to detrmine
|
||||||
# the staticmethods have a __func__ attribute.
|
# the staticmethods have a __func__ attribute.
|
||||||
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
|
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
|
||||||
@ -707,15 +708,15 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
|
|||||||
|
|
||||||
def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
|
def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
|
||||||
header_size, = struct.unpack('<Q', fp.read(8))
|
header_size, = struct.unpack('<Q', fp.read(8))
|
||||||
header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size))
|
header: dict[str, dict[str, Any]] = json.loads(fp.read(header_size))
|
||||||
# Use mmap for the actual data to avoid race conditions with the file offset.
|
# Use mmap for the actual data to avoid race conditions with the file offset.
|
||||||
mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
|
mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
|
||||||
byte_buf = mapped[8 + header_size:]
|
byte_buf = mapped[8 + header_size:]
|
||||||
|
|
||||||
def convert(info: Dict[str, Any]) -> LazyTensor:
|
def convert(info: dict[str, Any]) -> LazyTensor:
|
||||||
data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
|
data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
|
||||||
numpy_dtype = data_type.dtype
|
numpy_dtype = data_type.dtype
|
||||||
shape: List[int] = info['shape']
|
shape: list[int] = info['shape']
|
||||||
begin, end = info['data_offsets']
|
begin, end = info['data_offsets']
|
||||||
assert 0 <= begin <= end <= len(byte_buf)
|
assert 0 <= begin <= end <= len(byte_buf)
|
||||||
assert end - begin == math.prod(shape) * numpy_dtype.itemsize
|
assert end - begin == math.prod(shape) * numpy_dtype.itemsize
|
||||||
@ -754,7 +755,7 @@ def lazy_load_file(path: Path) -> ModelPlus:
|
|||||||
In = TypeVar('In')
|
In = TypeVar('In')
|
||||||
Out = TypeVar('Out')
|
Out = TypeVar('Out')
|
||||||
|
|
||||||
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]:
|
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: int | None = None, use_processpool_executor: bool = False) -> Iterable[Out]:
|
||||||
'''Parallel map, but with backpressure. If the caller doesn't call `next`
|
'''Parallel map, but with backpressure. If the caller doesn't call `next`
|
||||||
fast enough, this will stop calling `func` at some point rather than
|
fast enough, this will stop calling `func` at some point rather than
|
||||||
letting results pile up in memory. Specifically, there is a max of one
|
letting results pile up in memory. Specifically, there is a max of one
|
||||||
@ -763,13 +764,13 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
|
|||||||
yield from map(func, iterable)
|
yield from map(func, iterable)
|
||||||
# Not reached.
|
# Not reached.
|
||||||
iterable = iter(iterable)
|
iterable = iter(iterable)
|
||||||
executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
|
executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor]
|
||||||
if use_processpool_executor:
|
if use_processpool_executor:
|
||||||
executor_class = ProcessPoolExecutor
|
executor_class = ProcessPoolExecutor
|
||||||
else:
|
else:
|
||||||
executor_class = ThreadPoolExecutor
|
executor_class = ThreadPoolExecutor
|
||||||
with executor_class(max_workers = max_workers) as executor:
|
with executor_class(max_workers = max_workers) as executor:
|
||||||
futures: List[concurrent.futures.Future[Out]] = []
|
futures: list[concurrent.futures.Future[Out]] = []
|
||||||
done = False
|
done = False
|
||||||
for _ in range(concurrency):
|
for _ in range(concurrency):
|
||||||
try:
|
try:
|
||||||
@ -893,13 +894,13 @@ class OutputFile:
|
|||||||
of.close()
|
of.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
|
def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]:
|
||||||
name, lazy_tensor = item
|
name, lazy_tensor = item
|
||||||
tensor = lazy_tensor.load().to_ggml()
|
tensor = lazy_tensor.load().to_ggml()
|
||||||
return (lazy_tensor.data_type, tensor.ndarray)
|
return (lazy_tensor.data_type, tensor.ndarray)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
|
def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray:
|
||||||
dt, arr = item
|
dt, arr = item
|
||||||
if not isinstance(dt, QuantizedDataType):
|
if not isinstance(dt, QuantizedDataType):
|
||||||
return arr
|
return arr
|
||||||
@ -940,7 +941,7 @@ class OutputFile:
|
|||||||
|
|
||||||
of.close()
|
of.close()
|
||||||
|
|
||||||
def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
|
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
|
||||||
wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
|
wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
|
||||||
|
|
||||||
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
|
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
|
||||||
@ -960,7 +961,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
|
|||||||
|
|
||||||
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
||||||
tmap = gguf.TensorNameMap(ARCH, params.n_layer)
|
tmap = gguf.TensorNameMap(ARCH, params.n_layer)
|
||||||
should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
|
should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
|
||||||
|
|
||||||
tmp = model
|
tmp = model
|
||||||
|
|
||||||
@ -995,12 +996,12 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def nth_multifile_path(path: Path, n: int) -> Optional[Path]:
|
def nth_multifile_path(path: Path, n: int) -> Path | None:
|
||||||
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
|
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
|
||||||
the nth path in the model.
|
the nth path in the model.
|
||||||
'''
|
'''
|
||||||
# Support the following patterns:
|
# Support the following patterns:
|
||||||
patterns: List[Tuple[str, str]] = [
|
patterns: list[tuple[str, str]] = [
|
||||||
# - x.00.pth, x.01.pth, etc.
|
# - x.00.pth, x.01.pth, etc.
|
||||||
(r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
|
(r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
|
||||||
# - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
|
# - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
|
||||||
@ -1016,11 +1017,11 @@ def nth_multifile_path(path: Path, n: int) -> Optional[Path]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def find_multifile_paths(path: Path) -> List[Path]:
|
def find_multifile_paths(path: Path) -> list[Path]:
|
||||||
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
|
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
|
||||||
the whole list of paths in the model.
|
the whole list of paths in the model.
|
||||||
'''
|
'''
|
||||||
ret: List[Path] = []
|
ret: list[Path] = []
|
||||||
for i in itertools.count():
|
for i in itertools.count():
|
||||||
nth_path = nth_multifile_path(path, i)
|
nth_path = nth_multifile_path(path, i)
|
||||||
if nth_path is None:
|
if nth_path is None:
|
||||||
@ -1051,7 +1052,7 @@ def load_some_model(path: Path) -> ModelPlus:
|
|||||||
path = files[0]
|
path = files[0]
|
||||||
|
|
||||||
paths = find_multifile_paths(path)
|
paths = find_multifile_paths(path)
|
||||||
models_plus: List[ModelPlus] = []
|
models_plus: list[ModelPlus] = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
print(f"Loading model file {path}")
|
print(f"Loading model file {path}")
|
||||||
models_plus.append(lazy_load_file(path))
|
models_plus.append(lazy_load_file(path))
|
||||||
@ -1060,7 +1061,7 @@ def load_some_model(path: Path) -> ModelPlus:
|
|||||||
return model_plus
|
return model_plus
|
||||||
|
|
||||||
|
|
||||||
def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
|
def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
|
||||||
# Be extra-friendly and accept either a file or a directory. Also, if it's
|
# Be extra-friendly and accept either a file or a directory. Also, if it's
|
||||||
# a directory, it might be the model directory, and tokenizer.model might
|
# a directory, it might be the model directory, and tokenizer.model might
|
||||||
# be in the parent of that.
|
# be in the parent of that.
|
||||||
@ -1091,7 +1092,7 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, Sentence
|
|||||||
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
|
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
|
||||||
|
|
||||||
|
|
||||||
def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
|
def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
|
||||||
namestr = {
|
namestr = {
|
||||||
GGMLFileType.AllF32: "f32",
|
GGMLFileType.AllF32: "f32",
|
||||||
GGMLFileType.MostlyF16: "f16",
|
GGMLFileType.MostlyF16: "f16",
|
||||||
@ -1114,7 +1115,7 @@ def do_dump_model(model_plus: ModelPlus) -> None:
|
|||||||
print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}")
|
print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}")
|
||||||
|
|
||||||
|
|
||||||
def main(args_in: Optional[List[str]] = None) -> None:
|
def main(args_in: list[str] | None = None) -> None:
|
||||||
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
|
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
|
||||||
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
||||||
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import shutil
|
from __future__ import annotations
|
||||||
import sys
|
|
||||||
import struct
|
|
||||||
import tempfile
|
|
||||||
import numpy as np
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
import shutil
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from io import BufferedWriter
|
from io import BufferedWriter
|
||||||
from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union
|
from pathlib import Path
|
||||||
|
from typing import IO, Any, BinaryIO, Callable, Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
#
|
#
|
||||||
# constants
|
# constants
|
||||||
@ -103,7 +105,7 @@ class MODEL_TENSOR(IntEnum):
|
|||||||
FFN_NORM : int = auto()
|
FFN_NORM : int = auto()
|
||||||
|
|
||||||
|
|
||||||
MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
|
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
MODEL_ARCH.LLAMA: "llama",
|
MODEL_ARCH.LLAMA: "llama",
|
||||||
MODEL_ARCH.FALCON: "falcon",
|
MODEL_ARCH.FALCON: "falcon",
|
||||||
MODEL_ARCH.GPT2: "gpt2",
|
MODEL_ARCH.GPT2: "gpt2",
|
||||||
@ -112,7 +114,7 @@ MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.MPT: "mpt",
|
MODEL_ARCH.MPT: "mpt",
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
|
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
|
||||||
MODEL_ARCH.LLAMA: {
|
MODEL_ARCH.LLAMA: {
|
||||||
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
||||||
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
||||||
@ -158,7 +160,7 @@ MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# tensors that will not be serialized
|
# tensors that will not be serialized
|
||||||
MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
|
MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_ARCH.LLAMA: [
|
MODEL_ARCH.LLAMA: [
|
||||||
MODEL_TENSOR.ROPE_FREQS,
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
@ -167,7 +169,7 @@ MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
|
|||||||
|
|
||||||
|
|
||||||
class TensorNameMap:
|
class TensorNameMap:
|
||||||
mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
|
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||||
# Token embeddings
|
# Token embeddings
|
||||||
MODEL_TENSOR.TOKEN_EMBD: (
|
MODEL_TENSOR.TOKEN_EMBD: (
|
||||||
"gpt_neox.embed_in", # gptneox
|
"gpt_neox.embed_in", # gptneox
|
||||||
@ -203,7 +205,7 @@ class TensorNameMap:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
block_mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
|
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||||
# Attention norm
|
# Attention norm
|
||||||
MODEL_TENSOR.ATTN_NORM: (
|
MODEL_TENSOR.ATTN_NORM: (
|
||||||
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
|
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
|
||||||
@ -298,9 +300,9 @@ class TensorNameMap:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
mapping: Dict[str, Tuple[MODEL_TENSOR, str]]
|
mapping: dict[str, tuple[MODEL_TENSOR, str]]
|
||||||
|
|
||||||
tensor_names: Dict[MODEL_TENSOR, str]
|
tensor_names: dict[MODEL_TENSOR, str]
|
||||||
|
|
||||||
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
|
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
|
||||||
mapping = self.mapping = {}
|
mapping = self.mapping = {}
|
||||||
@ -321,7 +323,7 @@ class TensorNameMap:
|
|||||||
key = key.format(bid = bid)
|
key = key.format(bid = bid)
|
||||||
mapping[key] = (tensor, tensor_name)
|
mapping[key] = (tensor, tensor_name)
|
||||||
|
|
||||||
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
|
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None:
|
||||||
result = self.mapping.get(key)
|
result = self.mapping.get(key)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
@ -332,13 +334,13 @@ class TensorNameMap:
|
|||||||
return (result[0], result[1] + suffix)
|
return (result[0], result[1] + suffix)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
|
def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None:
|
||||||
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
|
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
|
||||||
if result is None:
|
if result is None:
|
||||||
return None
|
return None
|
||||||
return result[1]
|
return result[1]
|
||||||
|
|
||||||
def get_type(self, key: str, try_suffixes: Sequence[str]) -> Optional[MODEL_TENSOR]:
|
def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None:
|
||||||
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
|
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
|
||||||
if result is None:
|
if result is None:
|
||||||
return None
|
return None
|
||||||
@ -432,10 +434,10 @@ class GGUFWriter:
|
|||||||
ti_data = b""
|
ti_data = b""
|
||||||
ti_data_count = 0
|
ti_data_count = 0
|
||||||
use_temp_file: bool
|
use_temp_file: bool
|
||||||
temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
|
temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
|
||||||
tensors: List[Tuple[np.ndarray[Any, Any], int]]
|
tensors: list[tuple[np.ndarray[Any, Any], int]]
|
||||||
|
|
||||||
def __init__(self, path: Union[os.PathLike[str], str], arch: str, use_temp_file = True):
|
def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True):
|
||||||
self.fout = open(path, "wb")
|
self.fout = open(path, "wb")
|
||||||
self.arch = arch
|
self.arch = arch
|
||||||
self.add_architecture()
|
self.add_architecture()
|
||||||
@ -531,7 +533,7 @@ class GGUFWriter:
|
|||||||
GGUFValueType.FLOAT64: "<d",
|
GGUFValueType.FLOAT64: "<d",
|
||||||
GGUFValueType.BOOL: "?" ,
|
GGUFValueType.BOOL: "?" ,
|
||||||
}
|
}
|
||||||
def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
|
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True):
|
||||||
if vtype is None:
|
if vtype is None:
|
||||||
vtype = GGUFValueType.get_type(val)
|
vtype = GGUFValueType.get_type(val)
|
||||||
|
|
||||||
@ -561,7 +563,7 @@ class GGUFWriter:
|
|||||||
def ggml_pad(x: int, n: int) -> int:
|
def ggml_pad(x: int, n: int) -> int:
|
||||||
return ((x + n - 1) // n) * n
|
return ((x + n - 1) // n) * n
|
||||||
|
|
||||||
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: Union[np.dtype[np.float16], np.dtype[np.float32]], tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
|
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None):
|
||||||
assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
|
assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
|
||||||
|
|
||||||
encoded_name = name.encode("utf8")
|
encoded_name = name.encode("utf8")
|
||||||
@ -580,7 +582,7 @@ class GGUFWriter:
|
|||||||
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
||||||
self.ti_data_count += 1
|
self.ti_data_count += 1
|
||||||
|
|
||||||
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Optional[Sequence[int]] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
|
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None):
|
||||||
if self.use_temp_file and self.temp_file is None:
|
if self.use_temp_file and self.temp_file is None:
|
||||||
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
|
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
|
||||||
fp.seek(0)
|
fp.seek(0)
|
||||||
@ -600,7 +602,7 @@ class GGUFWriter:
|
|||||||
if pad != 0:
|
if pad != 0:
|
||||||
self.temp_file.write(bytes([0] * pad))
|
self.temp_file.write(bytes([0] * pad))
|
||||||
|
|
||||||
def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
|
def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
|
||||||
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
|
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
|
||||||
if pad != 0:
|
if pad != 0:
|
||||||
fp.write(bytes([0] * pad))
|
fp.write(bytes([0] * pad))
|
||||||
@ -726,13 +728,13 @@ class GGUFWriter:
|
|||||||
def add_tokenizer_model(self, model: str):
|
def add_tokenizer_model(self, model: str):
|
||||||
self.add_string(KEY_TOKENIZER_MODEL, model)
|
self.add_string(KEY_TOKENIZER_MODEL, model)
|
||||||
|
|
||||||
def add_token_list(self, tokens: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
|
def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
|
||||||
self.add_array(KEY_TOKENIZER_LIST, tokens)
|
self.add_array(KEY_TOKENIZER_LIST, tokens)
|
||||||
|
|
||||||
def add_token_merges(self, merges: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
|
def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
|
||||||
self.add_array(KEY_TOKENIZER_MERGES, merges)
|
self.add_array(KEY_TOKENIZER_MERGES, merges)
|
||||||
|
|
||||||
def add_token_types(self, types: Union[Sequence[TokenType], Sequence[int]]):
|
def add_token_types(self, types: Sequence[TokenType] | Sequence[int]):
|
||||||
self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
|
self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
|
||||||
|
|
||||||
def add_token_scores(self, scores: Sequence[float]):
|
def add_token_scores(self, scores: Sequence[float]):
|
||||||
@ -756,11 +758,11 @@ class GGUFWriter:
|
|||||||
|
|
||||||
class SpecialVocab:
|
class SpecialVocab:
|
||||||
load_merges: bool = False
|
load_merges: bool = False
|
||||||
merges: List[str] = []
|
merges: list[str] = []
|
||||||
special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
|
special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
|
||||||
special_token_ids: Dict[str, int] = {}
|
special_token_ids: dict[str, int] = {}
|
||||||
|
|
||||||
def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
|
def __init__(self, path: Path, load_merges: bool = False, special_token_types: tuple[str, ...] | None = None):
|
||||||
self.special_token_ids = {}
|
self.special_token_ids = {}
|
||||||
self.load_merges = load_merges
|
self.load_merges = load_merges
|
||||||
if special_token_types is not None:
|
if special_token_types is not None:
|
||||||
@ -821,7 +823,7 @@ class SpecialVocab:
|
|||||||
print(f'gguf: Adding {len(self.merges)} merge(s).')
|
print(f'gguf: Adding {len(self.merges)} merge(s).')
|
||||||
gw.add_token_merges(self.merges)
|
gw.add_token_merges(self.merges)
|
||||||
for typ, tokid in self.special_token_ids.items():
|
for typ, tokid in self.special_token_ids.items():
|
||||||
handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
|
handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
|
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
|
||||||
continue
|
continue
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "gguf"
|
name = "gguf"
|
||||||
version = "0.2.1"
|
version = "0.3.1"
|
||||||
description = "Write ML models in GGUF for GGML"
|
description = "Write ML models in GGUF for GGML"
|
||||||
authors = ["GGML <ggml@ggml.ai>"]
|
authors = ["GGML <ggml@ggml.ai>"]
|
||||||
packages = [
|
packages = [
|
||||||
|
Loading…
Reference in New Issue
Block a user