gguf : fix writing gguf arrays

This commit is contained in:
M. Yusuf Sarıgöz 2023-07-29 12:25:43 +03:00
parent d54f53ca51
commit ea5f9ad2ca

10
gguf.py
View File

@ -122,10 +122,11 @@ class GGUFWriter:
self.write_key(key) self.write_key(key)
self.write_val(val, GGUFValueType.ARRAY) self.write_val(val, GGUFValueType.ARRAY)
def write_val(self: str, val: Any, vtype: GGUFValueType = None): def write_val(self: str, val: Any, vtype: GGUFValueType = None, write_vtype: bool = True):
if vtype is None: if vtype is None:
vtype = GGUFValueType.get_type(val) vtype = GGUFValueType.get_type(val)
if write_vtype:
self.fout.write(struct.pack("<I", vtype)) self.fout.write(struct.pack("<I", vtype))
if vtype == GGUFValueType.UINT8: if vtype == GGUFValueType.UINT8:
@ -150,8 +151,10 @@ class GGUFWriter:
self.fout.write(encoded_val) self.fout.write(encoded_val)
elif vtype == GGUFValueType.ARRAY: elif vtype == GGUFValueType.ARRAY:
self.fout.write(struct.pack("<I", len(val))) self.fout.write(struct.pack("<I", len(val)))
# TODO: verify that all elements are of the same type
self.fout.write(struct.pack("<I", GGUFValueType.get_type(val[0])))
for item in val: for item in val:
self.write_val(item) self.write_val(item, write_vtype=False)
else: else:
raise ValueError("Invalid GGUF metadata value type") raise ValueError("Invalid GGUF metadata value type")
@ -177,8 +180,7 @@ class GGUFWriter:
self.tensors.append(tensor) self.tensors.append(tensor)
def write_tensors(self): def write_tensors(self):
offset_data = GGUFWriter.ggml_pad(self.fout.tell(), constants.GGUF_DEFAULT_ALIGNMENT) pad = GGUFWriter.ggml_pad(self.fout.tell(), constants.GGUF_DEFAULT_ALIGNMENT) - self.fout.tell()
pad = offset_data - self.fout.tell()
if pad != 0: if pad != 0:
self.fout.write(bytes([0] * pad)) self.fout.write(bytes([0] * pad))