gguf-py : fix dtype check (#6045)

This commit is contained in:
Georgi Gerganov 2024-03-14 13:32:14 +02:00
parent 15a333260a
commit 77178eedc8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -204,7 +204,7 @@ class GGUFWriter:
for i in range(n_dims): for i in range(n_dims):
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
if raw_dtype is None: if raw_dtype is None:
if tensor_shape == np.float32: if tensor_dtype == np.float32:
dtype = GGMLQuantizationType.F32 dtype = GGMLQuantizationType.F32
elif tensor_dtype == np.float16: elif tensor_dtype == np.float16:
dtype = GGMLQuantizationType.F16 dtype = GGMLQuantizationType.F16