add f16 convert

ngxson 2024-07-08 17:05:17 +02:00
parent 847135aaa2
commit 84288ff9f7


@@ -139,10 +139,17 @@ if __name__ == '__main__':
     # overwrite method
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
         return [(name, data_torch)]

+    # overwrite method
+    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+        del name, new_name, bid, n_dims  # unused
+        return True
+
     model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
     model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance)
+    model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
     model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)

     logger.info("Exporting model...")
     model_instance.write()
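
For context, here is a minimal, self-contained sketch of the types.MethodType pattern this hunk relies on: a plain function is bound to one existing instance so it overrides the class method for that object only. The Model class below is hypothetical and only stands in for the conversion model class patched by the script.

# Sketch, under the assumptions above. Requires Python 3.10+ for the `int | None` syntax.
import types


class Model:
    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
        return False  # default behaviour: no extra tensors are forced to f16


def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
    del name, new_name, bid, n_dims  # unused
    return True  # force every tensor to be written as f16


model_instance = Model()
# Bind the override to this one instance; other Model instances keep the default.
model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)

print(model_instance.extra_f16_tensors("blk.0.attn_q.weight", "blk.0.attn_q.weight", None, 2))  # True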