mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
convert : fix conversion of some BERT embedding models (#6937)
This commit is contained in:
parent
577277ffd2
commit
3055a41805
@ -2482,6 +2482,10 @@ class BertModel(Model):
|
||||
print(f"Can not map tensor {name!r}")
|
||||
sys.exit()
|
||||
|
||||
# convert any unsupported data types to float32
|
||||
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||
data_torch = data_torch.to(torch.float32)
|
||||
|
||||
data = data_torch.squeeze().numpy()
|
||||
n_dims = len(data.shape)
|
||||
new_dtype: type[np.floating[Any]]
|
||||
|
Loading…
Reference in New Issue
Block a user