convert : fix conversion of some BERT embedding models (#6937)

This commit is contained in:
Christian Zhou-Zheng 2024-04-29 09:34:41 -04:00 committed by GitHub
parent 577277ffd2
commit 3055a41805
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2482,6 +2482,10 @@ class BertModel(Model):
print(f"Can not map tensor {name!r}") print(f"Can not map tensor {name!r}")
sys.exit() sys.exit()
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
data = data_torch.squeeze().numpy() data = data_torch.squeeze().numpy()
n_dims = len(data.shape) n_dims = len(data.shape)
new_dtype: type[np.floating[Any]] new_dtype: type[np.floating[Any]]