py : fix scalar-tensor conversion [no ci]

This commit is contained in:
Georgi Gerganov 2024-09-17 13:40:52 +03:00
parent 3453e62bb9
commit 77723ed69e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -291,8 +291,13 @@ class Model:
bid = int(part)
break
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
data: np.ndarray # type hint
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
data = data_torch.squeeze().numpy()
# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data.shape) == 0:
data = data_torch.numpy()
n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)