Mirror of https://github.com/ggerganov/llama.cpp.git
convert-hf : save memory with lazy evaluation (#7075)
* convert-hf : begin refactoring write_tensor

* convert : upgrade to sentencepiece v0.2.0

* convert-hf : remove unused n_dims in extra_*_tensors

* convert-hf : simplify MoE weights stacking

* convert-hf : flake8 linter doesn't like semicolons

* convert-hf : allow unusual model part names

  For example, loading `model-00001-of-00001.safetensors` now works.

* convert-hf : fix stacking MoE expert tensors

  `torch.stack` and `torch.cat` don't do the same thing.

* convert-hf : fix Mamba conversion

  Tested to work even with a SentencePiece-based tokenizer.

* convert : use a string for the SentencePiece tokenizer path

* convert-hf : display tensor shape

* convert-hf : convert norms to f32 by default

* convert-hf : sort model part names

  `os.listdir` is said to list files in arbitrary order.
  Sorting the file names should let "model-00009-of-00042.safetensors"
  be loaded before "model-00010-of-00042.safetensors".

* convert-hf : use an ABC for Model again

  It seems Protocol can't be used as a statically type-checked ABC,
  because its subclasses also can't be instantiated. (why did it seem to work?)
  At least there's still a way to throw an error when forgetting to define
  the `model_arch` property of any registered Model subclasses.

* convert-hf : use a plain class for Model, and forbid direct instantiation

  There are no abstract methods used anyway,
  so using ABC isn't really necessary.

* convert-hf : more consistent formatting of cmdline args

* convert-hf : align the message logged for converted tensors

* convert-hf : fix Refact conversion

* convert-hf : save memory with lazy evaluation

* convert-hf : flake8 doesn't like lowercase L as a variable name

* convert-hf : remove einops requirement for InternLM2

* convert-hf : faster model parts loading

  Instead of pre-loading them all into a dict, iterate on the tensors in
  the model parts progressively as needed in Model.write_tensors

  Conversion for some architectures relies on checking for the presence of
  specific tensor names, so for multi-part models, the weight map is read
  from the relevant json file to quickly get these names up-front.

* convert-hf : minor changes for consistency

* gguf-py : add tqdm as a dependency

  It's small, and used for a progress bar in GGUFWriter.write_tensors_to_file
commit f98eb31c51
parent bc4bba364f
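Two notes in the message above are worth a concrete illustration (a minimal sketch, not code from the commit). On the MoE stacking fix: `torch.stack` joins tensors along a new leading dimension, while `torch.cat` joins them along an existing one, so confusing the two silently changes the stacked experts' shape:

import torch

# two stand-in "expert" weight matrices
a = torch.zeros(3, 4)
b = torch.ones(3, 4)

print(torch.stack([a, b]).shape)  # torch.Size([2, 3, 4]) - new leading experts dimension
print(torch.cat([a, b]).shape)    # torch.Size([6, 4])    - joined along the existing dim 0

On the part-name sorting: lexicographic order suffices here because the shard indices are zero-padded, so a plain `sorted(os.listdir(...))` already yields "model-00009-..." before "model-00010-...".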
File diff suppressed because it is too large
convert.py (20 changes)
@@ -284,6 +284,7 @@ class Params:
         n_experts = None
         n_experts_used = None
         f_rope_freq_base = None
+        n_ff = None
 
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
         if config.get("moe"):
@@ -308,6 +309,8 @@ class Params:
             n_experts_used = config["moe"]["num_experts_per_tok"]
             f_rope_freq_base = 1e6
 
+        assert n_ff is not None
+
         return Params(
             n_vocab = model["tok_embeddings.weight"].shape[0],
             n_embd = config["dim"],
@@ -462,7 +465,8 @@ class SentencePieceVocab(Vocab):
                 # not found in alternate location either
                 raise FileNotFoundError('Cannot find tokenizer.model')
 
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        self.sentencepiece_tokenizer = SentencePieceProcessor()
+        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
         vocab_size = self.sentencepiece_tokenizer.vocab_size()
 
         new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
@@ -482,23 +486,23 @@ class SentencePieceVocab(Vocab):
     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for i in range(tokenizer.vocab_size()):
-            piece = tokenizer.id_to_piece(i)
+            piece = tokenizer.IdToPiece(i)
             text = piece.encode("utf-8")
-            score: float = tokenizer.get_score(i)
+            score: float = tokenizer.GetScore(i)
 
             toktype = gguf.TokenType.NORMAL
-            if tokenizer.is_unknown(i):
+            if tokenizer.IsUnknown(i):
                 toktype = gguf.TokenType.UNKNOWN
-            if tokenizer.is_control(i):
+            if tokenizer.IsControl(i):
                 toktype = gguf.TokenType.CONTROL
 
             # NOTE: I think added_tokens are user defined.
             # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
             # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
 
-            if tokenizer.is_unused(i):
+            if tokenizer.IsUnused(i):
                 toktype = gguf.TokenType.UNUSED
-            if tokenizer.is_byte(i):
+            if tokenizer.IsByte(i):
                 toktype = gguf.TokenType.BYTE
 
             yield text, score, toktype
@@ -906,7 +910,7 @@ class LazyUnpickler(pickle.Unpickler):
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)
 
-    CLASSES = {
+    CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
         ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
@@ -939,7 +939,7 @@ async def oai_chat_completions(user_prompt,
     while event_received:
         event_received = False
         async for line_in_bytes in response.content:
-            line = line_in_bytes.decode('utf8')
+            line = line_in_bytes.decode('utf-8')
             line = line.rstrip('\n').rstrip('\r')
             if line == '':
                 continue
@@ -860,7 +860,7 @@ class GGUFValueType(IntEnum):
 # Note: Does not support GGML_QKK_64
 QK_K = 256
 # Items here are (block size, type size)
-GGML_QUANT_SIZES = {
+GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.F32: (1, 4),
     GGMLQuantizationType.F16: (1, 2),
     GGMLQuantizationType.Q4_0: (32, 2 + 16),
@@ -65,7 +65,7 @@ class ReaderTensor(NamedTuple):
 
 class GGUFReader:
     # I - same as host, S - swapped
-    byte_order: Literal['I' | 'S'] = 'I'
+    byte_order: Literal['I'] | Literal['S'] = 'I'
     alignment: int = GGUF_DEFAULT_ALIGNMENT
 
     # Note: Internal helper, API may change.
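The `Literal` rewrites in this hunk and the two that follow fix an invalid spelling: in `Literal['I' | 'S']` the `|` is applied to the strings themselves, which raises `TypeError` if ever evaluated and is rejected by type checkers; a union of `Literal` types is well-formed. A minimal sketch, assuming Python 3.10+ for the runtime `|` union:

from typing import Literal

ByteOrder = Literal['I'] | Literal['S']  # the form used in the diff
SameThing = Literal['I', 'S']            # the equivalent multi-value form

# 'I' | 'S' raises TypeError when evaluated, which is why the old annotation
# only appeared to work under postponed annotation evaluation.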
@@ -83,7 +83,7 @@ class GGUFReader:
         GGUFValueType.BOOL: np.bool_,
     }
 
-    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
         self.data = np.memmap(path, mode = mode)
         offs = 0
         if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
@@ -128,7 +128,7 @@ class GGUFReader:
         return self.tensors[idx]
 
     def _get(
-        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
     ) -> npt.NDArray[Any]:
         count = int(count)
         itemsize = int(np.empty([], dtype = dtype).itemsize)
@@ -250,7 +250,7 @@ class GGUFReader:
                 raise ValueError(f'Found duplicated tensor with name {tensor_name}')
             tensor_names.add(tensor_name)
             ggml_type = GGMLQuantizationType(raw_dtype[0])
-            n_elems = np.prod(dims)
+            n_elems = int(np.prod(dims))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
             data_offs = int(start_offs + offset_tensor[0])
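The `int(...)` wrapper around `np.prod` plausibly matters because `np.prod` returns a fixed-width NumPy scalar, so the later size arithmetic could wrap around for very large tensors, while a plain Python `int` has arbitrary precision. A small demonstration (the motivation is an assumption, not stated in the commit):

import numpy as np

print(np.int64(2**32) * np.int64(2**32))  # 0 - int64 multiplication wraps (NumPy may warn)
print(int(2**32) * 2**32)                 # 18446744073709551616 - Python ints never overflow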
@@ -7,7 +7,7 @@ import struct
 import tempfile
 from enum import Enum, auto
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping
+from typing import IO, Any, Callable, Sequence, Mapping
 from string import ascii_letters, digits
 
 import numpy as np
@@ -28,6 +28,47 @@ from .constants import (
 logger = logging.getLogger(__name__)
 
 
+class LazyTensor:
+    data: Callable[[], np.ndarray[Any, Any]]
+    # to avoid too deep recursion
+    functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
+    dtype: np.dtype[Any]
+    shape: tuple[int, ...]
+
+    def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
+        self.data = data
+        self.functions = []
+        self.dtype = np.dtype(dtype)
+        self.shape = shape
+
+    def astype(self, dtype: type, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.astype(dtype, **kwargs))
+        self.dtype = np.dtype(dtype)
+        return self
+
+    @property
+    def nbytes(self) -> int:
+        size = 1
+        for n in self.shape:
+            size *= n
+        return size * self.dtype.itemsize
+
+    def tofile(self, *args, **kwargs) -> None:
+        data = self.data()
+        for f in self.functions:
+            data = f(data)
+        assert data.shape == self.shape
+        assert data.dtype == self.dtype
+        assert data.nbytes == self.nbytes
+        self.functions = []
+        self.data = lambda: data
+        data.tofile(*args, **kwargs)
+
+    def byteswap(self, *args, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.byteswap(*args, **kwargs))
+        return self
+
+
 class WriterState(Enum):
     EMPTY = auto()
     HEADER = auto()
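A usage sketch of the `LazyTensor` added above (the `load` function is hypothetical; in the converter the callable reads one tensor out of a model part). Construction and `astype` only record work; the array is materialized once, inside `tofile`:

import numpy as np

def load() -> np.ndarray:
    # stand-in for reading a tensor out of a checkpoint shard
    return np.arange(6, dtype=np.float32).reshape(2, 3)

t = LazyTensor(load, dtype=np.float32, shape=(2, 3))
t = t.astype(np.float16)  # queued; nothing is loaded or converted yet
print(t.nbytes)           # 12, computed from shape and dtype alone

with open("tensor.bin", "wb") as f:
    t.tofile(f)           # load() runs here, then the queued astype, then the write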
@@ -38,7 +79,7 @@ class WriterState(Enum):
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any]]
+    tensors: list[np.ndarray[Any, Any] | LazyTensor]
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
         GGUFValueType.INT8: "b",
@@ -176,7 +217,7 @@ class GGUFWriter:
         if pack_fmt is not None:
             self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
         elif vtype == GGUFValueType.STRING:
-            encoded_val = val.encode("utf8") if isinstance(val, str) else val
+            encoded_val = val.encode("utf-8") if isinstance(val, str) else val
             self.kv_data += self._pack("Q", len(encoded_val))
             self.kv_data += encoded_val
         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@@ -205,7 +246,7 @@ class GGUFWriter:
             raise ValueError(f'Duplicated tensor name {name}')
         self.ti_names.add(name)
 
-        encoded_name = name.encode("utf8")
+        encoded_name = name.encode("utf-8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
         n_dims = len(tensor_shape)
@@ -237,7 +278,7 @@ class GGUFWriter:
         self.ti_data_count += 1
 
     def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.endianess == GGUFEndian.BIG:
@@ -262,7 +303,7 @@ class GGUFWriter:
         if pad != 0:
             fp.write(bytes([0] * pad))
 
-    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
         if self.state is not WriterState.TI_DATA:
             raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
 
@@ -272,15 +313,33 @@ class GGUFWriter:
         tensor.tofile(self.fout)
         self.write_padding(self.fout, tensor.nbytes)
 
-    def write_tensors_to_file(self) -> None:
+    def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()
 
         self.write_padding(self.fout, self.fout.tell())
 
         if self.temp_file is None:
+            self.tensors.reverse()  # to pop from the "beginning" in constant time
+
+            if progress:
+                from tqdm import tqdm
+
+                total_bytes = sum(t.nbytes for t in self.tensors)
+
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+                while True:
+                    try:
+                        tensor = self.tensors.pop()
+                    except IndexError:
+                        break
+                    tensor.tofile(self.fout)
+                    bar.update(tensor.nbytes)
+                    self.write_padding(self.fout, tensor.nbytes)
+                return
             while True:
                 try:
-                    tensor = self.tensors.pop(0)
+                    tensor = self.tensors.pop()
                 except IndexError:
                     break
                 tensor.tofile(self.fout)
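From the caller's side the new flag is opt-in, and per the `if self.temp_file is None` guard above, the progress path only applies when no temp file is used. A minimal sketch (file name and tensor are illustrative, metadata setup abbreviated):

import numpy as np
from gguf import GGUFWriter

writer = GGUFWriter("example.gguf", arch="llama", use_temp_file=False)
writer.add_tensor("dummy.weight", np.zeros((16, 16), dtype=np.float32))

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file(progress=True)  # tqdm bar in scaled byte units
writer.close()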
@@ -479,7 +538,7 @@ class GGUFWriter:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
 
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
-        if isinstance(value, list):
+        if not isinstance(value, str):
             template_default = None
             template_names = set()
 
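The `isinstance` change above lets `add_chat_template` accept any non-string collection of named templates (e.g. a tuple), matching the widened `chat_template` type in `SpecialVocab` below. For reference, a multi-template value looks roughly like this (template bodies are hypothetical):

writer.add_chat_template([
    {"name": "default",  "template": "{% for m in messages %}{{ m.content }}{% endfor %}"},
    {"name": "tool_use", "template": "{{ tools }}"},
])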
@@ -4,7 +4,7 @@ import logging
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any, Callable, Sequence, Mapping, Iterable
 
 from .gguf_writer import GGUFWriter
 
@@ -15,11 +15,11 @@ class SpecialVocab:
     merges: list[str]
     add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
-    chat_template: str | None
+    chat_template: str | Sequence[Mapping[str, str]] | None
 
     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
-        special_token_types: tuple[str, ...] | None = None,
+        special_token_types: Iterable[str] | None = None,
         n_vocab: int | None = None,
     ):
         self.special_token_ids = {}
@@ -21,6 +21,7 @@ classifiers = [
 [tool.poetry.dependencies]
 python = ">=3.8"
 numpy = ">=1.17"
+tqdm = ">=4.27"
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"
@@ -47,7 +47,7 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
         if len(field.types) == 1:
             curr_type = field.types[0]
             if curr_type == GGUFValueType.STRING:
-                log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60]))
+                log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
             elif field.types[0] in reader.gguf_scalar_to_np:
                 log_message += ' = {0}'.format(field.parts[-1][0])
         print(log_message) # noqa: NP100
@@ -7,7 +7,7 @@ import json
 from pathlib import Path
 
 import numpy as np
-from typing import Any, Mapping, Sequence
+from typing import Any, Sequence
 
 # Necessary to load the local gguf package
 if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@@ -34,7 +34,7 @@ def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
     return host_endian
 
 
-def decode_field(field: gguf.ReaderField) -> Any:
+def decode_field(field: gguf.ReaderField | None) -> Any:
     if field and field.types:
         main_type = field.types[0]
 
@@ -42,11 +42,11 @@ def decode_field(field: gguf.ReaderField) -> Any:
             sub_type = field.types[-1]
 
             if sub_type == gguf.GGUFValueType.STRING:
-                return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
+                return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
             else:
                 return [pv for idx in field.data for pv in field.parts[idx].tolist()]
         if main_type == gguf.GGUFValueType.STRING:
-            return str(bytes(field.parts[-1]), encoding='utf8')
+            return str(bytes(field.parts[-1]), encoding='utf-8')
         else:
             return field.parts[-1][0]
 
@@ -59,7 +59,7 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
     return decode_field(field)
 
 
-def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
     for field in reader.fields.values():
         # Suppress virtual fields and fields written by GGUFWriter
         if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@@ -101,7 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
 
     for tensor in reader.tensors:
         # Dimensions are written in reverse order, so flip them first
-        shape = np.flipud(tensor.shape)
+        shape = np.flipud(tensor.shape).tolist()
         writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
 
     writer.write_header_to_file()
pyrightconfig.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+    "extraPaths": ["gguf-py"],
+}
@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0
@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0
@@ -1,5 +1,5 @@
 numpy~=1.24.4
-sentencepiece~=0.1.98
+sentencepiece~=0.2.0
 transformers>=4.40.1,<5.0.0
 gguf>=0.1.0
 protobuf>=4.21.0,<5.0.0