mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-13 14:29:52 +00:00
gguf-py : decouple adding metadata from writing in GGUFWriter (#7827)
Main changes of this PR is to consolidate GGUFWriter.add_key and GGUFWriter.add_val into GGUFWriter.add_key_value. In addition use_temp_file is now opt-in instead of opt-out defaulting to False. Also GGUFWriter now does not require output file name until when actually writing to it. And GGUFWriter doesn't really need to eagerly prepare the data layout of the metadata
This commit is contained in:
parent
fe1e3917cf
commit
ed9f252118
@ -47,7 +47,7 @@ class Model:
|
|||||||
_model_classes: dict[str, type[Model]] = {}
|
_model_classes: dict[str, type[Model]] = {}
|
||||||
|
|
||||||
dir_model: Path
|
dir_model: Path
|
||||||
ftype: int
|
ftype: gguf.LlamaFileType
|
||||||
is_big_endian: bool
|
is_big_endian: bool
|
||||||
endianess: gguf.GGUFEndian
|
endianess: gguf.GGUFEndian
|
||||||
use_temp_file: bool
|
use_temp_file: bool
|
||||||
@ -94,7 +94,7 @@ class Model:
|
|||||||
ftype_lw: str = ftype_up.lower()
|
ftype_lw: str = ftype_up.lower()
|
||||||
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||||
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
||||||
self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls):
|
def __init_subclass__(cls):
|
||||||
@ -324,13 +324,13 @@ class Model:
|
|||||||
|
|
||||||
def write(self):
|
def write(self):
|
||||||
self.write_tensors()
|
self.write_tensors()
|
||||||
self.gguf_writer.write_header_to_file()
|
self.gguf_writer.write_header_to_file(self.fname_out)
|
||||||
self.gguf_writer.write_kv_data_to_file()
|
self.gguf_writer.write_kv_data_to_file()
|
||||||
self.gguf_writer.write_tensors_to_file(progress=True)
|
self.gguf_writer.write_tensors_to_file(progress=True)
|
||||||
self.gguf_writer.close()
|
self.gguf_writer.close()
|
||||||
|
|
||||||
def write_vocab(self):
|
def write_vocab(self):
|
||||||
self.gguf_writer.write_header_to_file()
|
self.gguf_writer.write_header_to_file(self.fname_out)
|
||||||
self.gguf_writer.write_kv_data_to_file()
|
self.gguf_writer.write_kv_data_to_file()
|
||||||
self.gguf_writer.close()
|
self.gguf_writer.close()
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import struct
|
import struct
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from io import BufferedWriter
|
from io import BufferedWriter
|
||||||
from typing import IO, Any, Sequence, Mapping
|
from typing import IO, Any, Sequence, Mapping
|
||||||
@ -30,17 +31,36 @@ from .quants import quant_shape_from_byte_shape
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TensorInfo:
|
||||||
|
shape: Sequence[int]
|
||||||
|
dtype: GGMLQuantizationType
|
||||||
|
nbytes: int
|
||||||
|
tensor: np.ndarray[Any, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GGUFValue:
|
||||||
|
value: Any
|
||||||
|
type: GGUFValueType
|
||||||
|
|
||||||
|
|
||||||
class WriterState(Enum):
|
class WriterState(Enum):
|
||||||
|
NO_FILE = auto()
|
||||||
EMPTY = auto()
|
EMPTY = auto()
|
||||||
HEADER = auto()
|
HEADER = auto()
|
||||||
KV_DATA = auto()
|
KV_DATA = auto()
|
||||||
TI_DATA = auto()
|
TI_DATA = auto()
|
||||||
|
WEIGHTS = auto()
|
||||||
|
|
||||||
|
|
||||||
class GGUFWriter:
|
class GGUFWriter:
|
||||||
fout: BufferedWriter
|
fout: BufferedWriter | None
|
||||||
|
path: os.PathLike[str] | str | None
|
||||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||||
tensors: list[np.ndarray[Any, Any]]
|
tensors: dict[str, TensorInfo]
|
||||||
|
kv_data: dict[str, GGUFValue]
|
||||||
|
state: WriterState
|
||||||
_simple_value_packing = {
|
_simple_value_packing = {
|
||||||
GGUFValueType.UINT8: "B",
|
GGUFValueType.UINT8: "B",
|
||||||
GGUFValueType.INT8: "b",
|
GGUFValueType.INT8: "b",
|
||||||
@ -56,141 +76,140 @@ class GGUFWriter:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
|
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
|
||||||
endianess: GGUFEndian = GGUFEndian.LITTLE,
|
endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||||
):
|
):
|
||||||
self.fout = open(path, "wb")
|
self.fout = None
|
||||||
|
self.path = path
|
||||||
self.arch = arch
|
self.arch = arch
|
||||||
self.endianess = endianess
|
self.endianess = endianess
|
||||||
self.offset_tensor = 0
|
|
||||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||||
self.kv_data = bytearray()
|
|
||||||
self.kv_data_count = 0
|
|
||||||
self.ti_data = bytearray()
|
|
||||||
self.ti_data_count = 0
|
|
||||||
self.ti_names = set()
|
|
||||||
self.use_temp_file = use_temp_file
|
self.use_temp_file = use_temp_file
|
||||||
self.temp_file = None
|
self.temp_file = None
|
||||||
self.tensors = []
|
self.tensors = dict()
|
||||||
|
self.kv_data = dict()
|
||||||
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
||||||
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
||||||
))
|
))
|
||||||
self.state = WriterState.EMPTY
|
self.state = WriterState.NO_FILE
|
||||||
|
|
||||||
self.add_architecture()
|
self.add_architecture()
|
||||||
|
|
||||||
def write_header_to_file(self) -> None:
|
def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||||
|
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
|
||||||
|
# allow calling this multiple times as long as the path is the same
|
||||||
|
return
|
||||||
|
if self.state is not WriterState.NO_FILE:
|
||||||
|
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
|
||||||
|
|
||||||
|
if path is not None:
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
if self.path is not None:
|
||||||
|
if self.fout is not None:
|
||||||
|
self.fout.close()
|
||||||
|
self.fout = open(self.path, "wb")
|
||||||
|
self.state = WriterState.EMPTY
|
||||||
|
|
||||||
|
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||||
|
self.open_output_file(path)
|
||||||
|
|
||||||
if self.state is not WriterState.EMPTY:
|
if self.state is not WriterState.EMPTY:
|
||||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||||
|
|
||||||
self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
|
self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
|
||||||
self._write_packed("I", GGUF_VERSION)
|
self._write_packed("I", GGUF_VERSION)
|
||||||
self._write_packed("Q", self.ti_data_count)
|
self._write_packed("Q", len(self.tensors))
|
||||||
self._write_packed("Q", self.kv_data_count)
|
self._write_packed("Q", len(self.kv_data))
|
||||||
self.flush()
|
self.flush()
|
||||||
self.state = WriterState.HEADER
|
self.state = WriterState.HEADER
|
||||||
|
|
||||||
def write_kv_data_to_file(self) -> None:
|
def write_kv_data_to_file(self) -> None:
|
||||||
if self.state is not WriterState.HEADER:
|
if self.state is not WriterState.HEADER:
|
||||||
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
self.fout.write(self.kv_data)
|
kv_data = bytearray()
|
||||||
|
|
||||||
|
for key, val in self.kv_data.items():
|
||||||
|
kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||||
|
kv_data += self._pack_val(val.value, val.type, add_vtype=True)
|
||||||
|
|
||||||
|
self.fout.write(kv_data)
|
||||||
self.flush()
|
self.flush()
|
||||||
self.state = WriterState.KV_DATA
|
self.state = WriterState.KV_DATA
|
||||||
|
|
||||||
def write_ti_data_to_file(self) -> None:
|
def write_ti_data_to_file(self) -> None:
|
||||||
if self.state is not WriterState.KV_DATA:
|
if self.state is not WriterState.KV_DATA:
|
||||||
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
self.fout.write(self.ti_data)
|
ti_data = bytearray()
|
||||||
|
offset_tensor = 0
|
||||||
|
|
||||||
|
for name, ti in self.tensors.items():
|
||||||
|
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||||
|
n_dims = len(ti.shape)
|
||||||
|
ti_data += self._pack("I", n_dims)
|
||||||
|
for i in range(n_dims):
|
||||||
|
ti_data += self._pack("Q", ti.shape[n_dims - 1 - i])
|
||||||
|
ti_data += self._pack("I", ti.dtype)
|
||||||
|
ti_data += self._pack("Q", offset_tensor)
|
||||||
|
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||||
|
|
||||||
|
self.fout.write(ti_data)
|
||||||
self.flush()
|
self.flush()
|
||||||
self.state = WriterState.TI_DATA
|
self.state = WriterState.TI_DATA
|
||||||
|
|
||||||
def add_key(self, key: str) -> None:
|
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||||
self.add_val(key, GGUFValueType.STRING, add_vtype=False)
|
if key in self.kv_data:
|
||||||
|
raise ValueError(f'Duplicated key name {key!r}')
|
||||||
|
|
||||||
|
self.kv_data[key] = GGUFValue(value=val, type=vtype)
|
||||||
|
|
||||||
def add_uint8(self, key: str, val: int) -> None:
|
def add_uint8(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key,val, GGUFValueType.UINT8)
|
||||||
self.add_val(val, GGUFValueType.UINT8)
|
|
||||||
|
|
||||||
def add_int8(self, key: str, val: int) -> None:
|
def add_int8(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT8)
|
||||||
self.add_val(val, GGUFValueType.INT8)
|
|
||||||
|
|
||||||
def add_uint16(self, key: str, val: int) -> None:
|
def add_uint16(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.UINT16)
|
||||||
self.add_val(val, GGUFValueType.UINT16)
|
|
||||||
|
|
||||||
def add_int16(self, key: str, val: int) -> None:
|
def add_int16(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT16)
|
||||||
self.add_val(val, GGUFValueType.INT16)
|
|
||||||
|
|
||||||
def add_uint32(self, key: str, val: int) -> None:
|
def add_uint32(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.UINT32)
|
||||||
self.add_val(val, GGUFValueType.UINT32)
|
|
||||||
|
|
||||||
def add_int32(self, key: str, val: int) -> None:
|
def add_int32(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT32)
|
||||||
self.add_val(val, GGUFValueType.INT32)
|
|
||||||
|
|
||||||
def add_float32(self, key: str, val: float) -> None:
|
def add_float32(self, key: str, val: float) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.FLOAT32)
|
||||||
self.add_val(val, GGUFValueType.FLOAT32)
|
|
||||||
|
|
||||||
def add_uint64(self, key: str, val: int) -> None:
|
def add_uint64(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.UINT64)
|
||||||
self.add_val(val, GGUFValueType.UINT64)
|
|
||||||
|
|
||||||
def add_int64(self, key: str, val: int) -> None:
|
def add_int64(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT64)
|
||||||
self.add_val(val, GGUFValueType.INT64)
|
|
||||||
|
|
||||||
def add_float64(self, key: str, val: float) -> None:
|
def add_float64(self, key: str, val: float) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.FLOAT64)
|
||||||
self.add_val(val, GGUFValueType.FLOAT64)
|
|
||||||
|
|
||||||
def add_bool(self, key: str, val: bool) -> None:
|
def add_bool(self, key: str, val: bool) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.BOOL)
|
||||||
self.add_val(val, GGUFValueType.BOOL)
|
|
||||||
|
|
||||||
def add_string(self, key: str, val: str) -> None:
|
def add_string(self, key: str, val: str) -> None:
|
||||||
if not val:
|
if not val:
|
||||||
return
|
return
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.STRING)
|
||||||
self.add_val(val, GGUFValueType.STRING)
|
|
||||||
|
|
||||||
def add_array(self, key: str, val: Sequence[Any]) -> None:
|
def add_array(self, key: str, val: Sequence[Any]) -> None:
|
||||||
if not isinstance(val, Sequence):
|
if not isinstance(val, Sequence):
|
||||||
raise ValueError("Value must be a sequence for array type")
|
raise ValueError("Value must be a sequence for array type")
|
||||||
|
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.ARRAY)
|
||||||
self.add_val(val, GGUFValueType.ARRAY)
|
|
||||||
|
|
||||||
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
|
|
||||||
if vtype is None:
|
|
||||||
vtype = GGUFValueType.get_type(val)
|
|
||||||
|
|
||||||
if add_vtype:
|
|
||||||
self.kv_data += self._pack("I", vtype)
|
|
||||||
self.kv_data_count += 1
|
|
||||||
|
|
||||||
pack_fmt = self._simple_value_packing.get(vtype)
|
|
||||||
if pack_fmt is not None:
|
|
||||||
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
|
|
||||||
elif vtype == GGUFValueType.STRING:
|
|
||||||
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
|
|
||||||
self.kv_data += self._pack("Q", len(encoded_val))
|
|
||||||
self.kv_data += encoded_val
|
|
||||||
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
|
|
||||||
ltype = GGUFValueType.get_type(val[0])
|
|
||||||
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
|
||||||
raise ValueError("All items in a GGUF array should be of the same type")
|
|
||||||
self.kv_data += self._pack("I", ltype)
|
|
||||||
self.kv_data += self._pack("Q", len(val))
|
|
||||||
for item in val:
|
|
||||||
self.add_val(item, add_vtype=False)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid GGUF metadata value type or value")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ggml_pad(x: int, n: int) -> int:
|
def ggml_pad(x: int, n: int) -> int:
|
||||||
@ -200,16 +219,12 @@ class GGUFWriter:
|
|||||||
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
|
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
|
||||||
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
|
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.state is not WriterState.EMPTY:
|
if self.state is not WriterState.NO_FILE:
|
||||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
|
||||||
|
|
||||||
if name in self.ti_names:
|
if name in self.tensors:
|
||||||
raise ValueError(f'Duplicated tensor name {name}')
|
raise ValueError(f'Duplicated tensor name {name!r}')
|
||||||
self.ti_names.add(name)
|
|
||||||
|
|
||||||
encoded_name = name.encode("utf-8")
|
|
||||||
self.ti_data += self._pack("Q", len(encoded_name))
|
|
||||||
self.ti_data += encoded_name
|
|
||||||
if raw_dtype is None:
|
if raw_dtype is None:
|
||||||
if tensor_dtype == np.float16:
|
if tensor_dtype == np.float16:
|
||||||
dtype = GGMLQuantizationType.F16
|
dtype = GGMLQuantizationType.F16
|
||||||
@ -231,14 +246,8 @@ class GGUFWriter:
|
|||||||
dtype = raw_dtype
|
dtype = raw_dtype
|
||||||
if tensor_dtype == np.uint8:
|
if tensor_dtype == np.uint8:
|
||||||
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
||||||
n_dims = len(tensor_shape)
|
|
||||||
self.ti_data += self._pack("I", n_dims)
|
self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
|
||||||
for i in range(n_dims):
|
|
||||||
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
|
|
||||||
self.ti_data += self._pack("I", dtype)
|
|
||||||
self.ti_data += self._pack("Q", self.offset_tensor)
|
|
||||||
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
|
||||||
self.ti_data_count += 1
|
|
||||||
|
|
||||||
def add_tensor(
|
def add_tensor(
|
||||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||||
@ -252,10 +261,10 @@ class GGUFWriter:
|
|||||||
self.temp_file = fp
|
self.temp_file = fp
|
||||||
|
|
||||||
shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
|
shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
|
||||||
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
|
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
|
||||||
|
|
||||||
if self.temp_file is None:
|
if self.temp_file is None:
|
||||||
self.tensors.append(tensor)
|
self.tensors[name].tensor = tensor
|
||||||
return
|
return
|
||||||
|
|
||||||
tensor.tofile(self.temp_file)
|
tensor.tofile(self.temp_file)
|
||||||
@ -267,8 +276,9 @@ class GGUFWriter:
|
|||||||
fp.write(bytes([0] * pad))
|
fp.write(bytes([0] * pad))
|
||||||
|
|
||||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||||
if self.state is not WriterState.TI_DATA:
|
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
|
||||||
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
|
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
if self.endianess == GGUFEndian.BIG:
|
if self.endianess == GGUFEndian.BIG:
|
||||||
tensor.byteswap(inplace=True)
|
tensor.byteswap(inplace=True)
|
||||||
@ -276,50 +286,51 @@ class GGUFWriter:
|
|||||||
tensor.tofile(self.fout)
|
tensor.tofile(self.fout)
|
||||||
self.write_padding(self.fout, tensor.nbytes)
|
self.write_padding(self.fout, tensor.nbytes)
|
||||||
|
|
||||||
|
self.state = WriterState.WEIGHTS
|
||||||
|
|
||||||
def write_tensors_to_file(self, *, progress: bool = False) -> None:
|
def write_tensors_to_file(self, *, progress: bool = False) -> None:
|
||||||
self.write_ti_data_to_file()
|
self.write_ti_data_to_file()
|
||||||
|
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
self.write_padding(self.fout, self.fout.tell())
|
self.write_padding(self.fout, self.fout.tell())
|
||||||
|
|
||||||
if self.temp_file is None:
|
if self.temp_file is None:
|
||||||
self.tensors.reverse() # to pop from the "beginning" in constant time
|
bar = None
|
||||||
|
|
||||||
if progress:
|
if progress:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
total_bytes = sum(t.nbytes for t in self.tensors)
|
total_bytes = sum(t.nbytes for t in self.tensors.values())
|
||||||
|
|
||||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||||
|
|
||||||
while True:
|
# relying on the fact that Python dicts preserve insertion order (since 3.7)
|
||||||
try:
|
for ti in self.tensors.values():
|
||||||
tensor = self.tensors.pop()
|
assert ti.tensor is not None # can only iterate once over the tensors
|
||||||
except IndexError:
|
assert ti.tensor.nbytes == ti.nbytes
|
||||||
break
|
ti.tensor.tofile(self.fout)
|
||||||
tensor.tofile(self.fout)
|
if bar is not None:
|
||||||
bar.update(tensor.nbytes)
|
bar.update(ti.nbytes)
|
||||||
self.write_padding(self.fout, tensor.nbytes)
|
self.write_padding(self.fout, ti.nbytes)
|
||||||
return
|
ti.tensor = None
|
||||||
while True:
|
else:
|
||||||
try:
|
self.temp_file.seek(0)
|
||||||
tensor = self.tensors.pop()
|
|
||||||
except IndexError:
|
|
||||||
break
|
|
||||||
tensor.tofile(self.fout)
|
|
||||||
self.write_padding(self.fout, tensor.nbytes)
|
|
||||||
return
|
|
||||||
|
|
||||||
self.temp_file.seek(0)
|
shutil.copyfileobj(self.temp_file, self.fout)
|
||||||
|
self.flush()
|
||||||
|
self.temp_file.close()
|
||||||
|
|
||||||
shutil.copyfileobj(self.temp_file, self.fout)
|
self.state = WriterState.WEIGHTS
|
||||||
self.flush()
|
|
||||||
self.temp_file.close()
|
|
||||||
|
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
|
assert self.fout is not None
|
||||||
self.fout.flush()
|
self.fout.flush()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.fout.close()
|
if self.fout is not None:
|
||||||
|
self.fout.close()
|
||||||
|
self.fout = None
|
||||||
|
|
||||||
def add_architecture(self) -> None:
|
def add_architecture(self) -> None:
|
||||||
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
||||||
@ -449,7 +460,7 @@ class GGUFWriter:
|
|||||||
def add_rope_scaling_factor(self, value: float) -> None:
|
def add_rope_scaling_factor(self, value: float) -> None:
|
||||||
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
|
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
|
||||||
|
|
||||||
def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None:
|
def add_rope_scaling_attn_factors(self, value: float) -> None:
|
||||||
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
|
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
|
||||||
|
|
||||||
def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
|
def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
|
||||||
@ -571,5 +582,32 @@ class GGUFWriter:
|
|||||||
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
||||||
return struct.pack(f'{pack_prefix}{fmt}', value)
|
return struct.pack(f'{pack_prefix}{fmt}', value)
|
||||||
|
|
||||||
|
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
|
||||||
|
kv_data = bytearray()
|
||||||
|
|
||||||
|
if add_vtype:
|
||||||
|
kv_data += self._pack("I", vtype)
|
||||||
|
|
||||||
|
pack_fmt = self._simple_value_packing.get(vtype)
|
||||||
|
if pack_fmt is not None:
|
||||||
|
kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
|
||||||
|
elif vtype == GGUFValueType.STRING:
|
||||||
|
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
|
||||||
|
kv_data += self._pack("Q", len(encoded_val))
|
||||||
|
kv_data += encoded_val
|
||||||
|
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
|
||||||
|
ltype = GGUFValueType.get_type(val[0])
|
||||||
|
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
||||||
|
raise ValueError("All items in a GGUF array should be of the same type")
|
||||||
|
kv_data += self._pack("I", ltype)
|
||||||
|
kv_data += self._pack("Q", len(val))
|
||||||
|
for item in val:
|
||||||
|
kv_data += self._pack_val(item, ltype, add_vtype=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid GGUF metadata value type or value")
|
||||||
|
|
||||||
|
return kv_data
|
||||||
|
|
||||||
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
|
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
|
||||||
|
assert self.fout is not None
|
||||||
self.fout.write(self._pack(fmt, value, skip_pack_prefix))
|
self.fout.write(self._pack(fmt, value, skip_pack_prefix))
|
||||||
|
@ -101,8 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
|||||||
logger.debug(f'Copying {field.name}')
|
logger.debug(f'Copying {field.name}')
|
||||||
|
|
||||||
if val.value is not None:
|
if val.value is not None:
|
||||||
writer.add_key(field.name)
|
writer.add_key_value(field.name, val.value, val.type)
|
||||||
writer.add_val(val.value, val.type)
|
|
||||||
|
|
||||||
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
||||||
logger.debug('Adding chat template(s)')
|
logger.debug('Adding chat template(s)')
|
||||||
@ -111,8 +110,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
|||||||
|
|
||||||
for key, val in new_metadata.items():
|
for key, val in new_metadata.items():
|
||||||
logger.debug(f'Adding {key}: "{val.value}" {val.description}')
|
logger.debug(f'Adding {key}: "{val.value}" {val.description}')
|
||||||
writer.add_key(key)
|
writer.add_key_value(key, val.value, val.type)
|
||||||
writer.add_val(val.value, val.type)
|
|
||||||
|
|
||||||
total_bytes = 0
|
total_bytes = 0
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user