fix ftype

ngxson 2024-07-08 21:55:41 +02:00
parent 84288ff9f7
commit 7a83f200d3


@@ -5,19 +5,12 @@ from __future__ import annotations
 import logging
 import argparse
-import contextlib
-import json
 import os
-import re
 import sys
 import types
-from enum import IntEnum
 from pathlib import Path
-from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
-import math
-import numpy as np
+from typing import TYPE_CHECKING, Iterable, Iterator
 import torch
 
 if TYPE_CHECKING:
@@ -32,22 +25,17 @@ from convert_hf_to_gguf import Model
 logger = logging.getLogger("lora-to-gguf")
 
 
 def parse_args() -> argparse.Namespace:
-    all_models = ", ".join([arch for arch in Model._model_classes.keys()])
     parser = argparse.ArgumentParser(
-        description="Convert a huggingface model to a GGML compatible file")
+        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
     parser.add_argument(
         "--outfile", type=Path,
         help="path to write to; default: based on input.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
-        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
-    )
-    parser.add_argument(
-        "--arch", type=str,
-        help=f"Arch of the base model, must be one of: {all_models} (default: LlamaForCausalLM)",
-        default="LlamaForCausalLM"
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
@@ -73,14 +61,13 @@ if __name__ == '__main__':
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
 
-    # FIXME: outtype is not working
     ftype_map: dict[str, gguf.LlamaFileType] = {
         "f32": gguf.LlamaFileType.ALL_F32,
         "f16": gguf.LlamaFileType.MOSTLY_F16,
         "bf16": gguf.LlamaFileType.MOSTLY_BF16,
         "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
-        "auto": gguf.LlamaFileType.GUESSED,
     }
+    ftype = ftype_map[args.outtype]
 
     dir_base_model = args.base
     dir_lora = args.lora_path
@@ -110,7 +97,7 @@ if __name__ == '__main__':
         logger.error(f"Model {hparams['architectures'][0]} is not supported")
         sys.exit(1)
 
-    model_instance = model_class(dir_base_model, ftype_map[args.outtype], fname_out, args.bigendian, False, False, None)
+    model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
 
     logger.info("Set model parameters")
     model_instance.set_gguf_parameters()
@@ -140,16 +127,18 @@ if __name__ == '__main__':
     # overwrite method
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
+        # TODO: This will not take into account tensor transformations
         return [(name, data_torch)]
 
     # overwrite method
     def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
         del name, new_name, bid, n_dims  # unused
-        return True
+        return ftype != gguf.LlamaFileType.ALL_F32
 
     model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
     model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance)
     model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
     model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
 
     logger.info("Exporting model...")
     model_instance.write()
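
In short, the `--outtype` flag was parsed but never applied (hence the removed `# FIXME: outtype is not working`): the chosen file type is now stored in `ftype`, passed to the model class, and consulted by the `extra_f16_tensors` override so that conversion of the LoRA tensors is only requested when the target type is not plain f32. The snippet below is a minimal, self-contained sketch of that decision logic; the `LlamaFileType` enum here is a stand-in with arbitrary values, not the real `gguf.LlamaFileType`.

from enum import IntEnum, auto


class LlamaFileType(IntEnum):
    """Stand-in for gguf.LlamaFileType; member values are arbitrary."""
    ALL_F32 = auto()
    MOSTLY_F16 = auto()
    MOSTLY_BF16 = auto()
    MOSTLY_Q8_0 = auto()


# Same mapping as in the diff: --outtype string -> target file type.
ftype_map: dict[str, LlamaFileType] = {
    "f32": LlamaFileType.ALL_F32,
    "f16": LlamaFileType.MOSTLY_F16,
    "bf16": LlamaFileType.MOSTLY_BF16,
    "q8_0": LlamaFileType.MOSTLY_Q8_0,
}


def extra_f16_tensors(ftype: LlamaFileType) -> bool:
    # After the fix: request the f16/bf16/q8_0 conversion path only when the
    # user did not ask for f32 output (before the fix this was always True).
    return ftype != LlamaFileType.ALL_F32


if __name__ == "__main__":
    for outtype in ("f32", "f16", "bf16", "q8_0"):
        ftype = ftype_map[outtype]
        print(f"--outtype {outtype}: convert LoRA tensors? {extra_f16_tensors(ftype)}")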