mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
Option to split during conversion (#6942)
* support splits in convert.py * Support split by size and dry run to write estimated shards/filesizes * Move split functionality to new GGUFManager class * fix improper function signature * tentative push of convert-hf-to-gguf support * resolve merge + SplitArguments for easier parsing * Fix eager tensor memory leak and remove convert.py changes Removed a memory leak caused by unexpected reference retention to eager tensors. Also removed GGUFManager functionality in convert.py in favor of specializing for convert-hf-to-gguf.py. * refactor SplitStrategy to be a deque Instead of having SplitStrategy have a `data` field that is a deque, just have SplitStrategy be a subclass of deque itself. * fix Q8 quantization * remove unnecessary imports in gguf_manager * fix final? merge issue * fix gguf_writer placement and remove comments * oops, actually fix gguf_writer placement * reduce duplicated code from gguf_writer * further simplify GGUFManager * simplify even further and standardize with GGUFWriter * reduce diffs with master * form shards while adding tensors, SHA256 sums agree with master * re-add type hint Co-authored-by: compilade <git@compilade.net> * GGUFWriter compatibility fix Co-authored-by: compilade <git@compilade.net> * Shard dataclass and un-negative dont_add_architecture * type consistency in format_n_bytes_to_str * move kv keys to constants.py * make pathlib explicit * base-1024 bytes to base-1000 * rename GGUFManager to GGUFWriterSplit * Update gguf-py/gguf/constants.py Co-authored-by: compilade <git@compilade.net> * fix convert-hf-to-gguf.py permissions * fix line endings * Update gguf-py/gguf/gguf_writer_split.py Co-authored-by: compilade <git@compilade.net> * convert-hf : restore executable file permission * examples/convert-legacy-llama.py: restore executable file permission * reinstate original gguf package import and fix type annotation * attempt to appease the linter * attempt 2 to appease the linter * attempt 3 to appease the linter * comma consistency * Update convert-hf-to-gguf.py Co-authored-by: compilade <git@compilade.net> * edit cmd line args * use simplification from #7827 * kv/ti data are still wrong * try to refactor kv data (still fails) * fix ti data messiness * tidy up * fix linting * actually make the linter happy * cleanup round 1 * remove SplitStrategy, SplitArguments * appease linter * fix typing and clean up * fix linting * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * progress bar, fix split logic * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * catch oversights * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * swap bar orders * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * compatibility fix * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <git@compilade.net> * Update convert-hf-to-gguf.py Co-authored-by: compilade <git@compilade.net> --------- Co-authored-by: Brian <mofosyne@gmail.com> Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
parent
8cb508d0d5
commit
52fc8705a0
@ -65,7 +65,8 @@ class Model:
|
||||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None):
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
|
||||
model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
|
||||
if type(self) is Model:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
self.dir_model = dir_model
|
||||
@ -96,7 +97,8 @@ class Model:
|
||||
ftype_lw: str = ftype_up.lower()
|
||||
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls):
|
||||
@ -332,6 +334,8 @@ class Model:
|
||||
self.gguf_writer.close()
|
||||
|
||||
def write_vocab(self):
|
||||
if len(self.gguf_writer.tensors) != 1:
|
||||
raise ValueError('Splitting the vocabulary is not supported')
|
||||
self.gguf_writer.write_header_to_file(self.fname_out)
|
||||
self.gguf_writer.write_kv_data_to_file()
|
||||
self.gguf_writer.close()
|
||||
@ -2974,10 +2978,44 @@ def parse_args() -> argparse.Namespace:
|
||||
"--verbose", action="store_true",
|
||||
help="increase output verbosity",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-max-tensors", type=int, default=0,
|
||||
help="max tensors in each split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-max-size", type=str, default="0",
|
||||
help="max size per split N(M|G)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run", action="store_true",
|
||||
help="only print out a split plan and exit, without writing any new files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-tensor-first-split", action="store_true",
|
||||
help="do not add tensors to the first split (disabled by default)"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def split_str_to_n_bytes(split_str: str) -> int:
|
||||
if split_str.endswith("K"):
|
||||
n = int(split_str[:-1]) * 1000
|
||||
elif split_str.endswith("M"):
|
||||
n = int(split_str[:-1]) * 1000 * 1000
|
||||
elif split_str.endswith("G"):
|
||||
n = int(split_str[:-1]) * 1000 * 1000 * 1000
|
||||
elif split_str.isnumeric():
|
||||
n = int(split_str)
|
||||
else:
|
||||
raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")
|
||||
|
||||
if n < 0:
|
||||
raise ValueError(f"Invalid split size: {split_str}, must be positive")
|
||||
|
||||
return n
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
@ -3010,6 +3048,10 @@ def main() -> None:
|
||||
"auto": gguf.LlamaFileType.GUESSED,
|
||||
}
|
||||
|
||||
if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
|
||||
logger.error("Error: Cannot use temp file when splitting")
|
||||
sys.exit(1)
|
||||
|
||||
if args.outfile is not None:
|
||||
fname_out = args.outfile
|
||||
else:
|
||||
@ -3027,7 +3069,10 @@ def main() -> None:
|
||||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name)
|
||||
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
|
||||
args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split)
|
||||
|
||||
logger.info("Set model parameters")
|
||||
model_instance.set_gguf_parameters()
|
||||
@ -3038,13 +3083,13 @@ def main() -> None:
|
||||
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
||||
|
||||
if args.vocab_only:
|
||||
logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
|
||||
logger.info("Exporting model vocab...")
|
||||
model_instance.write_vocab()
|
||||
logger.info("Model vocab successfully exported.")
|
||||
else:
|
||||
logger.info(f"Exporting model to '{model_instance.fname_out}'")
|
||||
logger.info("Exporting model...")
|
||||
model_instance.write()
|
||||
|
||||
logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
|
||||
logger.info("Model successfully exported.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -75,6 +75,11 @@ class Keys:
|
||||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
|
||||
|
||||
class Split:
|
||||
LLM_KV_SPLIT_NO = "split.no"
|
||||
LLM_KV_SPLIT_COUNT = "split.count"
|
||||
LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
|
||||
|
||||
class SSM:
|
||||
CONV_KERNEL = "{arch}.ssm.conv_kernel"
|
||||
INNER_SIZE = "{arch}.ssm.inner_size"
|
||||
|
@ -7,6 +7,7 @@ import struct
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from io import BufferedWriter
|
||||
from typing import IO, Any, Sequence, Mapping
|
||||
from string import ascii_letters, digits
|
||||
@ -31,6 +32,9 @@ from .quants import quant_shape_from_byte_shape
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
shape: Sequence[int]
|
||||
@ -55,11 +59,11 @@ class WriterState(Enum):
|
||||
|
||||
|
||||
class GGUFWriter:
|
||||
fout: BufferedWriter | None
|
||||
path: os.PathLike[str] | str | None
|
||||
fout: list[BufferedWriter] | None
|
||||
path: Path | None
|
||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||
tensors: dict[str, TensorInfo]
|
||||
kv_data: dict[str, GGUFValue]
|
||||
tensors: list[dict[str, TensorInfo]]
|
||||
kv_data: list[dict[str, GGUFValue]]
|
||||
state: WriterState
|
||||
_simple_value_packing = {
|
||||
GGUFValueType.UINT8: "B",
|
||||
@ -76,26 +80,38 @@ class GGUFWriter:
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
|
||||
endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
|
||||
):
|
||||
self.fout = None
|
||||
self.path = path
|
||||
self.path = Path(path) if path else None
|
||||
self.arch = arch
|
||||
self.endianess = endianess
|
||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||
self.use_temp_file = use_temp_file
|
||||
self.temp_file = None
|
||||
self.tensors = dict()
|
||||
self.kv_data = dict()
|
||||
self.tensors = [{}]
|
||||
self.kv_data = [{}]
|
||||
self.split_max_tensors = split_max_tensors
|
||||
self.split_max_size = split_max_size
|
||||
self.dry_run = dry_run
|
||||
self.small_first_shard = small_first_shard
|
||||
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
||||
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
||||
))
|
||||
self.state = WriterState.NO_FILE
|
||||
|
||||
if self.small_first_shard:
|
||||
self.tensors.append({})
|
||||
|
||||
self.add_architecture()
|
||||
|
||||
def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
def format_shard_names(self, path: Path) -> list[Path]:
|
||||
if len(self.tensors) == 1:
|
||||
return [path]
|
||||
return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
|
||||
|
||||
def open_output_file(self, path: Path | None = None) -> None:
|
||||
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
|
||||
# allow calling this multiple times as long as the path is the same
|
||||
return
|
||||
@ -106,22 +122,58 @@ class GGUFWriter:
|
||||
self.path = path
|
||||
|
||||
if self.path is not None:
|
||||
if self.fout is not None:
|
||||
self.fout.close()
|
||||
self.fout = open(self.path, "wb")
|
||||
filenames = self.print_plan()
|
||||
self.fout = [open(filename, "wb") for filename in filenames]
|
||||
self.state = WriterState.EMPTY
|
||||
|
||||
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
def print_plan(self) -> list[Path]:
|
||||
logger.info("Writing the following files:")
|
||||
assert self.path is not None
|
||||
filenames = self.format_shard_names(self.path)
|
||||
assert len(filenames) == len(self.tensors)
|
||||
for name, tensors in zip(filenames, self.tensors):
|
||||
logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
|
||||
|
||||
if self.dry_run:
|
||||
logger.info("Dry run, not writing files")
|
||||
exit()
|
||||
|
||||
return filenames
|
||||
|
||||
def add_shard_kv_data(self) -> None:
|
||||
if len(self.tensors) == 1:
|
||||
return
|
||||
|
||||
total_tensors = sum(len(t) for t in self.tensors)
|
||||
assert self.fout is not None
|
||||
total_splits = len(self.fout)
|
||||
self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
|
||||
for i, kv_data in enumerate(self.kv_data):
|
||||
kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
|
||||
kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
|
||||
kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
|
||||
|
||||
def write_header_to_file(self, path: Path | None = None) -> None:
|
||||
if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
|
||||
logger.warning("Model fails split requirements, not splitting")
|
||||
|
||||
self.open_output_file(path)
|
||||
|
||||
if self.state is not WriterState.EMPTY:
|
||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||
|
||||
self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
|
||||
self._write_packed("I", GGUF_VERSION)
|
||||
self._write_packed("Q", len(self.tensors))
|
||||
self._write_packed("Q", len(self.kv_data))
|
||||
self.flush()
|
||||
assert self.fout is not None
|
||||
assert len(self.fout) == len(self.tensors)
|
||||
assert len(self.kv_data) == 1
|
||||
|
||||
self.add_shard_kv_data()
|
||||
|
||||
for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
|
||||
fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
|
||||
fout.write(self._pack("I", GGUF_VERSION))
|
||||
fout.write(self._pack("Q", len(tensors)))
|
||||
fout.write(self._pack("Q", len(kv_data)))
|
||||
fout.flush()
|
||||
self.state = WriterState.HEADER
|
||||
|
||||
def write_kv_data_to_file(self) -> None:
|
||||
@ -129,13 +181,15 @@ class GGUFWriter:
|
||||
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
kv_data = bytearray()
|
||||
for fout, kv_data in zip(self.fout, self.kv_data):
|
||||
kv_bytes = bytearray()
|
||||
|
||||
for key, val in self.kv_data.items():
|
||||
kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_data += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
for key, val in kv_data.items():
|
||||
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
|
||||
fout.write(kv_bytes)
|
||||
|
||||
self.fout.write(kv_data)
|
||||
self.flush()
|
||||
self.state = WriterState.KV_DATA
|
||||
|
||||
@ -144,28 +198,29 @@ class GGUFWriter:
|
||||
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
ti_data = bytearray()
|
||||
offset_tensor = 0
|
||||
for fout, tensors in zip(self.fout, self.tensors):
|
||||
ti_data = bytearray()
|
||||
offset_tensor = 0
|
||||
|
||||
for name, ti in self.tensors.items():
|
||||
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||
n_dims = len(ti.shape)
|
||||
ti_data += self._pack("I", n_dims)
|
||||
for i in range(n_dims):
|
||||
ti_data += self._pack("Q", ti.shape[n_dims - 1 - i])
|
||||
ti_data += self._pack("I", ti.dtype)
|
||||
ti_data += self._pack("Q", offset_tensor)
|
||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||
for name, ti in tensors.items():
|
||||
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||
n_dims = len(ti.shape)
|
||||
ti_data += self._pack("I", n_dims)
|
||||
for j in range(n_dims):
|
||||
ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
|
||||
ti_data += self._pack("I", ti.dtype)
|
||||
ti_data += self._pack("Q", offset_tensor)
|
||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||
|
||||
self.fout.write(ti_data)
|
||||
self.flush()
|
||||
fout.write(ti_data)
|
||||
fout.flush()
|
||||
self.state = WriterState.TI_DATA
|
||||
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||
if key in self.kv_data:
|
||||
if any(key in kv_data for kv_data in self.kv_data):
|
||||
raise ValueError(f'Duplicated key name {key!r}')
|
||||
|
||||
self.kv_data[key] = GGUFValue(value=val, type=vtype)
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
|
||||
|
||||
def add_uint8(self, key: str, val: int) -> None:
|
||||
self.add_key_value(key,val, GGUFValueType.UINT8)
|
||||
@ -206,9 +261,6 @@ class GGUFWriter:
|
||||
self.add_key_value(key, val, GGUFValueType.STRING)
|
||||
|
||||
def add_array(self, key: str, val: Sequence[Any]) -> None:
|
||||
if not isinstance(val, Sequence):
|
||||
raise ValueError("Value must be a sequence for array type")
|
||||
|
||||
self.add_key_value(key, val, GGUFValueType.ARRAY)
|
||||
|
||||
@staticmethod
|
||||
@ -222,7 +274,7 @@ class GGUFWriter:
|
||||
if self.state is not WriterState.NO_FILE:
|
||||
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
|
||||
|
||||
if name in self.tensors:
|
||||
if any(name in tensors for tensors in self.tensors):
|
||||
raise ValueError(f'Duplicated tensor name {name!r}')
|
||||
|
||||
if raw_dtype is None:
|
||||
@ -247,7 +299,18 @@ class GGUFWriter:
|
||||
if tensor_dtype == np.uint8:
|
||||
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
||||
|
||||
self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
|
||||
# make sure there is at least one tensor before splitting
|
||||
if len(self.tensors[-1]) > 0:
|
||||
if ( # split when over tensor limit
|
||||
self.split_max_tensors != 0
|
||||
and len(self.tensors[-1]) >= self.split_max_tensors
|
||||
) or ( # split when over size limit
|
||||
self.split_max_size != 0
|
||||
and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
|
||||
):
|
||||
self.tensors.append({})
|
||||
|
||||
self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
|
||||
|
||||
def add_tensor(
|
||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||
@ -264,7 +327,7 @@ class GGUFWriter:
|
||||
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
|
||||
|
||||
if self.temp_file is None:
|
||||
self.tensors[name].tensor = tensor
|
||||
self.tensors[-1][name].tensor = tensor
|
||||
return
|
||||
|
||||
tensor.tofile(self.temp_file)
|
||||
@ -282,9 +345,24 @@ class GGUFWriter:
|
||||
|
||||
if self.endianess == GGUFEndian.BIG:
|
||||
tensor.byteswap(inplace=True)
|
||||
self.write_padding(self.fout, self.fout.tell())
|
||||
tensor.tofile(self.fout)
|
||||
self.write_padding(self.fout, tensor.nbytes)
|
||||
|
||||
file_id = -1
|
||||
for i, tensors in enumerate(self.tensors):
|
||||
if len(tensors) > 0:
|
||||
file_id = i
|
||||
break
|
||||
|
||||
fout = self.fout[file_id]
|
||||
|
||||
# pop the first tensor info
|
||||
# TODO: cleaner way to get the first key
|
||||
first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
|
||||
ti = self.tensors[file_id].pop(first_tensor_name)
|
||||
assert ti.nbytes == tensor.nbytes
|
||||
|
||||
self.write_padding(fout, fout.tell())
|
||||
tensor.tofile(fout)
|
||||
self.write_padding(fout, tensor.nbytes)
|
||||
|
||||
self.state = WriterState.WEIGHTS
|
||||
|
||||
@ -293,31 +371,43 @@ class GGUFWriter:
|
||||
|
||||
assert self.fout is not None
|
||||
|
||||
self.write_padding(self.fout, self.fout.tell())
|
||||
for fout in self.fout:
|
||||
self.write_padding(fout, fout.tell())
|
||||
|
||||
if self.temp_file is None:
|
||||
shard_bar = None
|
||||
bar = None
|
||||
|
||||
if progress:
|
||||
from tqdm import tqdm
|
||||
|
||||
total_bytes = sum(t.nbytes for t in self.tensors.values())
|
||||
total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
|
||||
|
||||
if len(self.fout) > 1:
|
||||
shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
|
||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||
|
||||
# relying on the fact that Python dicts preserve insertion order (since 3.7)
|
||||
for ti in self.tensors.values():
|
||||
assert ti.tensor is not None # can only iterate once over the tensors
|
||||
assert ti.tensor.nbytes == ti.nbytes
|
||||
ti.tensor.tofile(self.fout)
|
||||
if bar is not None:
|
||||
bar.update(ti.nbytes)
|
||||
self.write_padding(self.fout, ti.nbytes)
|
||||
ti.tensor = None
|
||||
for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
|
||||
if shard_bar is not None:
|
||||
shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
|
||||
total = sum(ti.nbytes for ti in tensors.values())
|
||||
shard_bar.reset(total=(total if total > 0 else None))
|
||||
|
||||
# relying on the fact that Python dicts preserve insertion order (since 3.7)
|
||||
for ti in tensors.values():
|
||||
assert ti.tensor is not None # can only iterate once over the tensors
|
||||
assert ti.tensor.nbytes == ti.nbytes
|
||||
ti.tensor.tofile(fout)
|
||||
if shard_bar is not None:
|
||||
shard_bar.update(ti.nbytes)
|
||||
if bar is not None:
|
||||
bar.update(ti.nbytes)
|
||||
self.write_padding(fout, ti.nbytes)
|
||||
ti.tensor = None
|
||||
else:
|
||||
self.temp_file.seek(0)
|
||||
|
||||
shutil.copyfileobj(self.temp_file, self.fout)
|
||||
shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
|
||||
self.flush()
|
||||
self.temp_file.close()
|
||||
|
||||
@ -325,11 +415,13 @@ class GGUFWriter:
|
||||
|
||||
def flush(self) -> None:
|
||||
assert self.fout is not None
|
||||
self.fout.flush()
|
||||
for fout in self.fout:
|
||||
fout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
if self.fout is not None:
|
||||
self.fout.close()
|
||||
for fout in self.fout:
|
||||
fout.close()
|
||||
self.fout = None
|
||||
|
||||
def add_architecture(self) -> None:
|
||||
@ -626,6 +718,13 @@ class GGUFWriter:
|
||||
|
||||
return kv_data
|
||||
|
||||
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
|
||||
assert self.fout is not None
|
||||
self.fout.write(self._pack(fmt, value, skip_pack_prefix))
|
||||
@staticmethod
|
||||
def format_n_bytes_to_str(num: int) -> str:
|
||||
if num == 0:
|
||||
return "negligible - metadata only"
|
||||
fnum = float(num)
|
||||
for unit in ("", "K", "M", "G"):
|
||||
if abs(fnum) < 1000.0:
|
||||
return f"{fnum:3.1f}{unit}"
|
||||
fnum /= 1000.0
|
||||
return f"{fnum:.1f}T - over 1TB, split recommended"
|
||||
|
Loading…
Reference in New Issue
Block a user