From 847135aaa25ae999060ddb8431f5d529f9244389 Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 8 Jul 2024 16:35:27 +0200 Subject: [PATCH 1/6] add convert script --- convert_lora_to_gguf.py | 149 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100755 convert_lora_to_gguf.py diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py new file mode 100755 index 000000000..9a5c7a2c8 --- /dev/null +++ b/convert_lora_to_gguf.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import logging +import argparse +import contextlib +import json +import os +import re +import sys +import types +from enum import IntEnum +from pathlib import Path +from hashlib import sha256 +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast + +import math +import numpy as np +import torch + +if TYPE_CHECKING: + from torch import Tensor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +# reuse model definitions from convert_hf_to_gguf.py +from convert_hf_to_gguf import Model + +logger = logging.getLogger("lora-to-gguf") + +def parse_args() -> argparse.Namespace: + all_models = ", ".join([arch for arch in Model._model_classes.keys()]) + parser = argparse.ArgumentParser( + description="Convert a huggingface model to a GGML compatible file") + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input.", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + ) + parser.add_argument( + "--arch", type=str, + help=f"Arch of the base model, must be one of: {all_models} (default: LlamaForCausalLM)", + default="LlamaForCausalLM" + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + parser.add_argument( + "--base", type=Path, required=True, + help="directory containing base model file", + ) + parser.add_argument( + "lora_path", type=Path, + help="directory containing LoRA adapter file", + ) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + # FIXME: outtype is not working + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "auto": gguf.LlamaFileType.GUESSED, + } + + dir_base_model = args.base + dir_lora = args.lora_path + input_json = os.path.join(dir_lora, "adapter_config.json") + input_model = os.path.join(dir_lora, "adapter_model.bin") + if args.outfile is not None: + fname_out = args.outfile + else: + # output in the same directory as the model by default + fname_out = dir_lora / 'ggml-lora.gguf' + + if os.path.exists(input_model): + lora_model = torch.load(input_model, map_location="cpu") + else: + input_model = os.path.join(dir_lora, "adapter_model.safetensors") + # lazy import load_file only if lora is in safetensors format. + from safetensors.torch import load_file + lora_model = load_file(input_model, device="cpu") + + # load base model + logger.info(f"Loading base model: {dir_base_model.name}") + hparams = Model.load_hparams(dir_base_model) + with torch.inference_mode(): + try: + model_class = Model.from_model_architecture(hparams["architectures"][0]) + except NotImplementedError: + logger.error(f"Model {hparams['architectures'][0]} is not supported") + sys.exit(1) + + model_instance = model_class(dir_base_model, ftype_map[args.outtype], fname_out, args.bigendian, False, False, None) + logger.info("Set model parameters") + model_instance.set_gguf_parameters() + + # adapter_config = json.load(input_json) + model_instance.gguf_writer.add_string("training.type", "finetune_lora") + + map_tensors: dict[str, Tensor] = {} + for tensor_name, tensor in lora_model.items(): + orig_name = tensor_name.replace("base_model.model.", "") + orig_name = orig_name.replace(".lora_A.weight", ".weight") + orig_name = orig_name.replace(".lora_B.weight", ".weight") + is_lora_a = ".lora_A.weight" in tensor_name + is_lora_b = ".lora_B.weight" in tensor_name + if not is_lora_a and not is_lora_b: + logger.error(f"Unexpected name '{tensor_name}': Not a lora_A or lora_B tensor") + sys.exit(1) + dest_name = model_instance.map_tensor_name(orig_name) + dest_name = f"{dest_name}.lora_a" if is_lora_a else f"{dest_name}.lora_b" + # logger.info(f"{orig_name} --> {dest_name}") + map_tensors[dest_name] = tensor + + # overwrite method + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for name, tensor in map_tensors.items(): + yield (name, tensor) + + # overwrite method + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + return [(name, data_torch)] + + model_instance.get_tensors = types.MethodType(get_tensors, model_instance) + model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance) + model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + logger.info("Exporting model...") + model_instance.write() + logger.info(f"Model successfully exported to {fname_out}") From 84288ff9f7e945bb730bb0df069ecf2054ba6076 Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 8 Jul 2024 17:05:17 +0200 Subject: [PATCH 2/6] add f16 convert --- convert_lora_to_gguf.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 9a5c7a2c8..36ccb73cf 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -139,10 +139,17 @@ if __name__ == '__main__': # overwrite method def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused return [(name, data_torch)] + # overwrite method + def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: + del name, new_name, bid, n_dims # unused + return True + model_instance.get_tensors = types.MethodType(get_tensors, model_instance) model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance) + model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance) model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) logger.info("Exporting model...") model_instance.write() From 7a83f200d353db68fef8458017c7db17b0a303c4 Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 8 Jul 2024 21:55:41 +0200 Subject: [PATCH 3/6] fix ftype --- convert_lora_to_gguf.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 36ccb73cf..861ab1e97 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -5,19 +5,12 @@ from __future__ import annotations import logging import argparse -import contextlib -import json import os -import re import sys import types -from enum import IntEnum from pathlib import Path -from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Iterable, Iterator -import math -import numpy as np import torch if TYPE_CHECKING: @@ -32,22 +25,17 @@ from convert_hf_to_gguf import Model logger = logging.getLogger("lora-to-gguf") + def parse_args() -> argparse.Namespace: - all_models = ", ".join([arch for arch in Model._model_classes.keys()]) parser = argparse.ArgumentParser( - description="Convert a huggingface model to a GGML compatible file") + description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file") parser.add_argument( "--outfile", type=Path, help="path to write to; default: based on input.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", - help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", - ) - parser.add_argument( - "--arch", type=str, - help=f"Arch of the base model, must be one of: {all_models} (default: LlamaForCausalLM)", - default="LlamaForCausalLM" + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0", ) parser.add_argument( "--bigendian", action="store_true", @@ -73,14 +61,13 @@ if __name__ == '__main__': args = parse_args() logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - # FIXME: outtype is not working ftype_map: dict[str, gguf.LlamaFileType] = { "f32": gguf.LlamaFileType.ALL_F32, "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, - "auto": gguf.LlamaFileType.GUESSED, } + ftype = ftype_map[args.outtype] dir_base_model = args.base dir_lora = args.lora_path @@ -110,7 +97,7 @@ if __name__ == '__main__': logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) - model_instance = model_class(dir_base_model, ftype_map[args.outtype], fname_out, args.bigendian, False, False, None) + model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None) logger.info("Set model parameters") model_instance.set_gguf_parameters() @@ -140,16 +127,18 @@ if __name__ == '__main__': # overwrite method def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused + # TODO: This will not take into account tensor transformations return [(name, data_torch)] # overwrite method def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: del name, new_name, bid, n_dims # unused - return True + return ftype != gguf.LlamaFileType.ALL_F32 model_instance.get_tensors = types.MethodType(get_tensors, model_instance) model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance) model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance) + model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) logger.info("Exporting model...") model_instance.write() From d52455f2bec45d7e6df8da5b26b91d969ce4580d Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 8 Jul 2024 22:00:13 +0200 Subject: [PATCH 4/6] add requirements --- requirements/requirements-convert_lora_to_gguf.txt | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 requirements/requirements-convert_lora_to_gguf.txt diff --git a/requirements/requirements-convert_lora_to_gguf.txt b/requirements/requirements-convert_lora_to_gguf.txt new file mode 100644 index 000000000..5758076c4 --- /dev/null +++ b/requirements/requirements-convert_lora_to_gguf.txt @@ -0,0 +1,2 @@ +-r ./requirements-convert_hf_to_gguf.txt +--extra-index-url https://download.pytorch.org/whl/cpu From 802565ca4327c3dbc02b83ad25ecd4b2bd8253b7 Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 8 Jul 2024 22:01:23 +0200 Subject: [PATCH 5/6] fix requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 52456c2e6..9e190ae27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ -r ./requirements/requirements-convert_hf_to_gguf.txt -r ./requirements/requirements-convert_hf_to_gguf_update.txt -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt +-r ./requirements/requirements-convert_lora_to_gguf.txt From 95b3eb057b0261a48aeadcb1524a1f58d7ef39cc Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 8 Jul 2024 22:05:35 +0200 Subject: [PATCH 6/6] fix outfile --- convert_lora_to_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 861ab1e97..76c673101 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -31,7 +31,7 @@ def parse_args() -> argparse.Namespace: description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file") parser.add_argument( "--outfile", type=Path, - help="path to write to; default: based on input.", + help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", @@ -77,7 +77,7 @@ if __name__ == '__main__': fname_out = args.outfile else: # output in the same directory as the model by default - fname_out = dir_lora / 'ggml-lora.gguf' + fname_out = dir_lora / 'ggml-lora-{ftype}.gguf' if os.path.exists(input_model): lora_model = torch.load(input_model, map_location="cpu")