diff --git a/gguf.py b/gguf.py index 8e2c771ba..88e2bee07 100644 --- a/gguf.py +++ b/gguf.py @@ -61,98 +61,113 @@ class GGUFWriter: def __init__(self, fout: IO): self.fout = fout self.offset_tensor = 0 + self.kv_data = b"" + self.kv_data_count = 0 + self.ti_data = b"" + self.ti_data_count = 0 - def write_header(self, tensor_count: int, metadata_kv_count: int): + def write_header_to_file(self): self.fout.write(struct.pack(" "GGUFWriter": f = open(path, "wb") return cls(f) - def write_key(self, key: str): - self.write_val(key, GGUFValueType.STRING, write_vtype=False) + def add_key(self, key: str): + self.add_val(key, GGUFValueType.STRING, add_vtype=False) - def write_uint8(self, key: str, val: int): - self.write_key(key) - self.write_val(val, GGUFValueType.UINT8) + def add_uint8(self, key: str, val: int): + self.add_key(key) + self.add_val(val, GGUFValueType.UINT8) - def write_int8(self, key: str, val: int): - self.write_key(key) - self.write_val(val, GGUFValueType.INT8) + def add_int8(self, key: str, val: int): + self.add_key(key) + self.add_val(val, GGUFValueType.INT8) - def write_uint16(self, key: str, val: int): - self.write_key(key) - self.write_val(val, GGUFValueType.UINT16) + def add_uint16(self, key: str, val: int): + self.add_key(key) + self.add_val(val, GGUFValueType.UINT16) - def write_int16(self, key: str, val: int): - self.write_key(key) - self.write_val(val, GGUFValueType.INT16) + def add_int16(self, key: str, val: int): + self.add_key(key) + self.add_val(val, GGUFValueType.INT16) - def write_uint32(self, key: str, val: int): - self.write_key(key) - self.write_val(val, GGUFValueType.UINT32) + def add_uint32(self, key: str, val: int): + self.add_key(key) + self.add_val(val, GGUFValueType.UINT32) - def write_int32(self, key: str, val: int): - self.write_key(key) - self.write_val(val, GGUFValueType.INT32) + def add_int32(self, key: str, val: int): + self.add_key(key) + self.add_val(val, GGUFValueType.INT32) - def write_float32(self, key: str, val: float): - self.write_key(key) - self.write_val(val, GGUFValueType.FLOAT32) + def add_float32(self, key: str, val: float): + self.add_key(key) + self.add_val(val, GGUFValueType.FLOAT32) - def write_bool(self, key: str, val: bool): - self.write_key(key) - self.write_val(val, GGUFValueType.BOOL) + def add_bool(self, key: str, val: bool): + self.add_key(key) + self.add_val(val, GGUFValueType.BOOL) - def write_string(self, key: str, val: str): - self.write_key(key) - self.write_val(val, GGUFValueType.STRING) + def add_string(self, key: str, val: str): + self.add_key(key) + self.add_val(val, GGUFValueType.STRING) - def write_array(self, key: str, val: list): + def add_array(self, key: str, val: list): if not isinstance(val, list): raise ValueError("Value must be a list for array type") - self.write_key(key) - self.write_val(val, GGUFValueType.ARRAY) + self.add_key(key) + self.add_val(val, GGUFValueType.ARRAY) - def write_val(self: str, val: Any, vtype: GGUFValueType = None, write_vtype: bool = True): + def add_val(self: str, val: Any, vtype: GGUFValueType = None, add_vtype: bool = True): if vtype is None: vtype = GGUFValueType.get_type(val) - if write_vtype: - self.fout.write(struct.pack(" int: return ((x + n - 1) // n) * n - def write_tensor_info(self, name: str, tensor: np.ndarray): - self.write_key(name) + def add_tensor_info(self, name: str, tensor: np.ndarray): + encoded_name = name.encode("utf8") + self.ti_data += struct.pack("