diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e1ac09e02..b51d68307 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -43,17 +43,18 @@ AnyModel = TypeVar("AnyModel", bound="type[Model]") class Model(ABC): _model_classes: dict[str, type[Model]] = {} - def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool): + def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool): self.dir_model = dir_model self.ftype = ftype self.fname_out = fname_out self.is_big_endian = is_big_endian self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.use_temp_file = use_temp_file self.is_safetensors = self._is_model_safetensors() self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") self.part_names = self._get_part_names() self.hparams = Model.load_hparams(self.dir_model) - self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False) + self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) @property @@ -2459,6 +2460,7 @@ def parse_args() -> argparse.Namespace: "model", type=Path, help="directory containing model file", ) + parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)") return parser.parse_args() @@ -2502,7 +2504,7 @@ def main() -> None: with torch.inference_mode(): model_class = Model.from_model_architecture(hparams["architectures"][0]) - model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file) print("Set model parameters") model_instance.set_gguf_parameters()