mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
convert : fix conversion of some BERT embedding models (#6937)
This commit is contained in:
parent
577277ffd2
commit
3055a41805
@ -2482,6 +2482,10 @@ class BertModel(Model):
|
|||||||
print(f"Can not map tensor {name!r}")
|
print(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
|
# convert any unsupported data types to float32
|
||||||
|
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||||
|
data_torch = data_torch.to(torch.float32)
|
||||||
|
|
||||||
data = data_torch.squeeze().numpy()
|
data = data_torch.squeeze().numpy()
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
new_dtype: type[np.floating[Any]]
|
new_dtype: type[np.floating[Any]]
|
||||||
|
Loading…
Reference in New Issue
Block a user